Custom Checkpointing IO

Warning

The Checkpoint IO API is experimental and subject to change.

Lightning supports customizing the checkpoint save/load functionality through the CheckpointIO plugin, which encapsulates the save/load logic that is otherwise managed by the TrainingTypePlugin.

CheckpointIO can be extended to implement your own logic for saving to and loading from a path. The CheckpointIO object can be passed to either a Trainer object or a TrainingTypePlugin, as shown below.

from pathlib import Path
from typing import Any, Dict, Optional, Union

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO, SingleDevicePlugin


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        ...

    def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Pass into the Trainer object
model = MyModel()
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# Pass into the TrainingTypePlugin
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    plugins=SingleDevicePlugin(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
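
If you want a concrete starting point, the sketch below continues the example above (same imports) and fills in the two methods with plain torch.save/torch.load. The class name, the temporary-file rename, and the CPU map_location are illustrative choices for this sketch, not requirements of the CheckpointIO API.

import os


class TorchFileCheckpointIO(CheckpointIO):
    def save_checkpoint(
        self, checkpoint: Dict[str, Any], path: Union[str, Path], storage_options: Optional[Any] = None
    ) -> None:
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        # Write to a temporary file first and rename it afterwards, so an
        # interrupted save never leaves a truncated checkpoint behind.
        tmp_path = path.with_suffix(path.suffix + ".tmp")
        torch.save(checkpoint, tmp_path)
        os.replace(tmp_path, path)

    def load_checkpoint(self, path: Union[str, Path], storage_options: Optional[Any] = None) -> Dict[str, Any]:
        # Load onto CPU so checkpoints written on an accelerator can be
        # restored on any device.
        return torch.load(path, map_location="cpu")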

Note

Some TrainingTypePlugins do not support custom CheckpointIO, as their checkpointing logic is not modifiable.
