Custom Checkpointing IO
Warning
The Checkpoint IO API is experimental and subject to change.
Lightning supports modifying the checkpoint save/load functionality through the CheckpointIO plugin, which encapsulates the save/load logic managed by the TrainingTypePlugin.
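To make that division of labour concrete, here is a framework-free sketch in plain Python. `CheckpointIOSketch`, `InMemoryCheckpointIO` and `PluginSketch` are hypothetical names, not Lightning classes. The plugin stand-in decides when a checkpoint is written or read, while the IO object it was given decides how and where the checkpoint dictionary is stored. Lightning's real classes carry more responsibilities than shown here.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional


class CheckpointIOSketch(ABC):
    """Hypothetical stand-in for the CheckpointIO interface (save/load only)."""

    @abstractmethod
    def save_checkpoint(self, checkpoint: Dict[str, Any], path: str, storage_options: Optional[Any] = None) -> None:
        ...

    @abstractmethod
    def load_checkpoint(self, path: str, storage_options: Optional[Any] = None) -> Dict[str, Any]:
        ...


class InMemoryCheckpointIO(CheckpointIOSketch):
    """Hypothetical IO that keeps checkpoints in a dict keyed by path."""

    def __init__(self) -> None:
        self.store: Dict[str, Dict[str, Any]] = {}

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        # Shallow copy so later changes to the top-level dict do not leak into the store.
        self.store[str(path)] = dict(checkpoint)

    def load_checkpoint(self, path, storage_options=None):
        return self.store[str(path)]


class PluginSketch:
    """Hypothetical plugin: owns *when* to save/load, delegates *how* to checkpoint_io."""

    def __init__(self, checkpoint_io: CheckpointIOSketch) -> None:
        self.checkpoint_io = checkpoint_io

    def save_checkpoint(self, checkpoint: Dict[str, Any], filepath: str) -> None:
        self.checkpoint_io.save_checkpoint(checkpoint, filepath)

    def load_checkpoint(self, filepath: str) -> Dict[str, Any]:
        return self.checkpoint_io.load_checkpoint(filepath)


plugin = PluginSketch(checkpoint_io=InMemoryCheckpointIO())
plugin.save_checkpoint({"epoch": 3, "state_dict": {"w": 0.5}}, "last.ckpt")
restored = plugin.load_checkpoint("last.ckpt")
print(restored["epoch"])  # prints 3
```

Swapping `InMemoryCheckpointIO` for another implementation changes where checkpoints go without touching the plugin, which is the point of keeping the IO logic behind its own interface.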
Extend CheckpointIO to implement your own logic for saving checkpoints to, and loading them from, a path. The resulting CheckpointIO object can be passed to either the Trainer or a TrainingTypePlugin, as shown below.
from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        # Put your custom save logic here; torch.save is only a placeholder.
        torch.save(checkpoint, path)

    def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
        # Put your custom load logic here; it must return the saved checkpoint dict.
        return torch.load(path)


custom_checkpoint_io = CustomCheckpointIO()

# Pass into the Trainer object
model = MyModel()  # MyModel is your own LightningModule (not defined here)
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=[ModelCheckpoint(save_last=True)],
)
trainer.fit(model)

# Pass into the TrainingTypePlugin
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io),
    callbacks=[ModelCheckpoint(save_last=True)],
)
trainer.fit(model)
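One possible reason to customise checkpoint IO is to make writes crash-safe, so that an interrupted save never leaves a truncated `last.ckpt` behind. The sketch below, assuming atomic replacement is the behaviour you want, writes to a temporary file in the destination directory and then swaps it into place with `os.replace`. It is deliberately framework-free so it runs without Lightning or PyTorch installed; `AtomicPickleCheckpointIO` is a hypothetical name. In a real project you would subclass CheckpointIO as in the example above and would likely use `torch.save`/`torch.load` instead of `pickle`.

```python
import os
import pickle
import tempfile
from pathlib import Path
from typing import Any, Dict, Optional, Union


class AtomicPickleCheckpointIO:
    """Hypothetical IO: write to a temp file, then rename it over the target."""

    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Create the temp file next to the target so the rename stays on one filesystem.
        fd, tmp_name = tempfile.mkstemp(dir=str(path.parent), suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                pickle.dump(checkpoint, f)
            # On POSIX, a successful os.replace is an atomic rename.
            os.replace(tmp_name, path)
        except BaseException:
            # Remove the partial temp file if writing or renaming failed.
            if os.path.exists(tmp_name):
                os.remove(tmp_name)
            raise

    def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
        with open(path, "rb") as f:
            return pickle.load(f)


with tempfile.TemporaryDirectory() as tmp:
    checkpoint_io = AtomicPickleCheckpointIO()
    target = Path(tmp) / "checkpoints" / "last.ckpt"
    checkpoint_io.save_checkpoint({"epoch": 1, "global_step": 100}, target)
    exists_after_save = target.exists()
    # No stray temp files should remain after a successful save.
    leftovers = list(target.parent.glob("*.tmp"))
    loaded = checkpoint_io.load_checkpoint(target)
    print(loaded)  # prints {'epoch': 1, 'global_step': 100}
```

Readers of the checkpoint path therefore see either the previous file or the complete new one, never a half-written file.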
Note
Some TrainingTypePlugins do not support a custom CheckpointIO, because their checkpointing logic is not modifiable.