
Checkpointing (expert)

Writing your own Checkpoint class

Lightning provides a Checkpoint base class to make subclassing easier. If you write a custom checkpointing callback in the style of ModelCheckpoint, subclass Checkpoint so that the Trainer recognizes your class as a checkpointing callback.

Customize Checkpointing

Warning

The Checkpoint IO API is experimental and subject to change.

Lightning lets you modify how checkpoints are saved and loaded through a CheckpointIO plugin, which encapsulates the save/load logic that the Strategy manages. CheckpointIO differs from the on_save_checkpoint() and on_load_checkpoint() hooks. The hooks control what is stored in the checkpoint. CheckpointIO controls how the checkpoint is written to and read from storage.


Built-in Checkpoint IO Plugins

Plugin | Description
TorchCheckpointIO | CheckpointIO that uses torch.save() and torch.load() to save and load checkpoints, respectively. Suitable for most use cases.
XLACheckpointIO | CheckpointIO that uses xm.save() to save checkpoints for TPU training strategies.
HPUCheckpointIO | CheckpointIO that saves checkpoints for HPU training strategies.
AsyncCheckpointIO | CheckpointIO that saves checkpoints asynchronously in a background thread.

Custom Checkpoint IO Plugin

You can subclass CheckpointIO to implement your own logic for saving checkpoints to a path, loading them from it, and removing them. Pass the resulting object either directly to the Trainer or to a Strategy, as shown below:

import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        # Write the checkpoint dict to `path` (for example, local disk or an object store).
        ...

    def load_checkpoint(self, path, storage_options=None):
        # Read the checkpoint stored at `path` and return it as a dict.
        ...

    def remove_checkpoint(self, path):
        # Delete the checkpoint stored at `path`.
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Either pass it to the Trainer directly
model = MyModel()  # MyModel stands for your own LightningModule
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# or pass it to a Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

Note

Some strategies, such as DeepSpeedStrategy, do not support a custom CheckpointIO because their checkpointing logic cannot be modified.

Asynchronous Checkpointing

Warning

This plugin is experimental, and its API is likely to change.

To save checkpoints asynchronously without blocking training, pass the AsyncCheckpointIO plugin to the Trainer.

from pytorch_lightning import Trainer
from pytorch_lightning.plugins.io import AsyncCheckpointIO


async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])

AsyncCheckpointIO uses the saving logic of a base CheckpointIO plugin but runs it asynchronously. By default, this base CheckpointIO is set up for you, so you only need to pass the AsyncCheckpointIO instance to the Trainer. To make your own custom CheckpointIO save asynchronously, pass it as the checkpoint_io argument when you initialize AsyncCheckpointIO.
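The idea behind asynchronous checkpointing can be sketched without Lightning: a single background worker runs the base saver's logic while the caller continues. This is a conceptual sketch that uses only the standard library. It is not Lightning's AsyncCheckpointIO implementation. PickleSaver, BackgroundSaver, and demo are hypothetical names, and pickle stands in for the base CheckpointIO's serialization.

```python
import copy
import os
import pickle
import tempfile
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, Optional


class PickleSaver:
    """Hypothetical synchronous 'base' saver that writes a dict with pickle."""

    def save(self, checkpoint: Dict[str, Any], path: str) -> None:
        with open(path, "wb") as f:
            pickle.dump(checkpoint, f)


class BackgroundSaver:
    """Conceptual sketch: runs a base saver's save() on one background thread."""

    def __init__(self, base: PickleSaver) -> None:
        self.base = base
        self._executor = ThreadPoolExecutor(max_workers=1)  # one worker keeps saves in order
        self._error: Optional[Exception] = None

    def save(self, checkpoint: Dict[str, Any], path: str) -> None:
        snapshot = copy.deepcopy(checkpoint)  # the caller may keep mutating the original
        self._executor.submit(self._run, snapshot, path)  # returns immediately

    def _run(self, checkpoint: Dict[str, Any], path: str) -> None:
        try:
            self.base.save(checkpoint, path)
        except Exception as exc:  # remember the failure instead of losing it in the thread
            self._error = exc

    def teardown(self) -> None:
        self._executor.shutdown(wait=True)  # block until all pending saves finish
        if self._error is not None:
            raise RuntimeError("asynchronous save failed") from self._error


def demo() -> Dict[str, Any]:
    """Usage example: save three 'epochs' without waiting, then read back the last one."""
    state = {"epoch": 0, "weights": [1.0, 2.0]}
    with tempfile.TemporaryDirectory() as tmpdir:
        saver = BackgroundSaver(PickleSaver())
        for epoch in range(3):
            state["epoch"] = epoch
            saver.save(state, os.path.join(tmpdir, f"epoch={epoch}.pkl"))
            state["weights"].append(float(epoch))  # "training" keeps changing the state
        saver.teardown()
        with open(os.path.join(tmpdir, "epoch=2.pkl"), "rb") as f:
            return pickle.load(f)
```

The sketch makes two design choices. First, it deep-copies the checkpoint at submission time, so later updates during training cannot leak into a pending save. Second, it stores errors raised in the worker thread and re-raises them in teardown(); otherwise they would be silently lost.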

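from pytorch_lightning import Trainer
from pytorch_lightning.plugins.io import AsyncCheckpointIO

# MyCustomCheckpointIO stands for your own CheckpointIO subclass
base_ckpt_io = MyCustomCheckpointIO()
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])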

© Copyright 2018-2023, Lightning AI et al.
