
Distributed checkpoints (expert)

Writing your own Checkpoint class

We provide a Checkpoint class to make subclassing easier. Subclass it when writing a custom ModelCheckpoint-style callback so that the Trainer recognizes your class as a checkpointing callback.

Customize Checkpointing

Warning

This is an experimental feature.

Lightning supports modifying the checkpoint save/load functionality through CheckpointIO, which encapsulates the save/load logic managed by the Strategy. CheckpointIO differs from the on_save_checkpoint() and on_load_checkpoint() methods: it determines how a checkpoint is saved to and loaded from storage, rather than what is saved in the checkpoint.
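The "what" versus "how" distinction can be sketched in plain Python. This is only an illustration that runs without Lightning installed: add_extra_state is a hypothetical stand-in for a hook such as on_save_checkpoint(), and JsonFileIO is a hypothetical stand-in for a CheckpointIO implementation. Neither name exists in Lightning.

```python
import json
import os
import tempfile


def add_extra_state(checkpoint):
    """Hook-style function (hypothetical): decides WHAT goes into the checkpoint."""
    checkpoint["data_version"] = 3
    return checkpoint


class JsonFileIO:
    """IO-style object (hypothetical): decides HOW the dictionary reaches storage."""

    def save_checkpoint(self, checkpoint, path):
        with open(path, "w", encoding="utf-8") as f:
            json.dump(checkpoint, f)

    def load_checkpoint(self, path):
        with open(path, encoding="utf-8") as f:
            return json.load(f)


# The hook changes the contents; the IO object only moves bytes.
checkpoint = add_extra_state({"epoch": 1, "global_step": 100})
path = os.path.join(tempfile.mkdtemp(), "example.json")
file_io = JsonFileIO()
file_io.save_checkpoint(checkpoint, path)
restored = file_io.load_checkpoint(path)
```

Swapping JsonFileIO for another storage backend would leave the checkpoint contents untouched, which is the separation CheckpointIO is meant to provide.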


Built-in Checkpoint IO Plugins

Plugin: Description

TorchCheckpointIO: CheckpointIO that uses torch.save() and torch.load() to save and load checkpoints, respectively. Suitable for most use cases.

XLACheckpointIO: CheckpointIO that uses xm.save to save checkpoints for TPU training strategies.

AsyncCheckpointIO: CheckpointIO that saves checkpoints asynchronously in a thread.

Custom Checkpoint IO Plugin

CheckpointIO can be extended to include your custom save/load functionality to and from a path. The CheckpointIO object can be passed to either a Trainer directly or a Strategy as shown below:

import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.plugins import CheckpointIO
from lightning.pytorch.strategies import SingleDeviceStrategy


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        ...

    def load_checkpoint(self, path, storage_options=None):
        ...

    def remove_checkpoint(self, path):
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Either pass into the Trainer object
# (MyModel is a placeholder for your own LightningModule)
model = MyModel()
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
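As a rough idea of what the three methods might contain, here is a minimal sketch written against the standard library only, so it runs without Lightning or PyTorch. PickleCheckpointIO is a hypothetical name. In real use the class would subclass CheckpointIO as shown above and would more likely call torch.save() and torch.load(). The write-to-temporary-file-then-rename step is a design choice of this sketch, not something the Lightning documentation prescribes.

```python
import os
import pickle
import tempfile


class PickleCheckpointIO:
    """Standalone sketch (hypothetical); in Lightning this would subclass CheckpointIO."""

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        directory = os.path.dirname(os.path.abspath(path))
        os.makedirs(directory, exist_ok=True)
        # Write to a temporary file in the same directory, then rename it,
        # so a crash mid-write does not leave a truncated file at `path`.
        fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
        try:
            with os.fdopen(fd, "wb") as f:
                pickle.dump(checkpoint, f)
            os.replace(tmp_path, path)
        except BaseException:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise

    def load_checkpoint(self, path, storage_options=None):
        # Only unpickle files you trust.
        with open(path, "rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path):
        if os.path.exists(path):
            os.remove(path)


ckpt_io = PickleCheckpointIO()
ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoints", "last.ckpt")
ckpt_io.save_checkpoint({"epoch": 2, "state_dict": {"w": [0.1, 0.2]}}, ckpt_path)
loaded = ckpt_io.load_checkpoint(ckpt_path)
ckpt_io.remove_checkpoint(ckpt_path)
still_exists = os.path.exists(ckpt_path)
```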

Note

Some Strategy classes, such as DeepSpeedStrategy, do not support a custom CheckpointIO because their checkpointing logic is not modifiable.

Asynchronous Checkpointing

Warning

This is an experimental feature.

To save checkpoints asynchronously without blocking training, pass the AsyncCheckpointIO plugin to the Trainer.

from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO


async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])

AsyncCheckpointIO uses the saving logic of its base CheckpointIO plugin but performs the save asynchronously. By default, the base CheckpointIO is set up for you, so you only need to pass the AsyncCheckpointIO instance to the Trainer. To make your own custom CheckpointIO behave asynchronously, pass it as the checkpoint_io argument when initializing AsyncCheckpointIO.

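from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO

# MyCustomCheckpointIO stands for your own CheckpointIO subclass
base_ckpt_io = MyCustomCheckpointIO()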
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])
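The wrapping idea described above can be illustrated with a small standalone sketch. It is a simplified illustration of delegating saves to a worker thread, not Lightning's actual AsyncCheckpointIO implementation. InMemoryIO and AsyncWrapperIO are hypothetical names, and the teardown method here is an assumption about how pending saves might be flushed.

```python
import time
from concurrent.futures import ThreadPoolExecutor


class InMemoryIO:
    """Hypothetical base IO that 'saves' into a dict after an artificial delay."""

    def __init__(self):
        self.storage = {}

    def save_checkpoint(self, checkpoint, path):
        time.sleep(0.05)  # simulate slow storage
        self.storage[path] = checkpoint


class AsyncWrapperIO:
    """Illustrative wrapper: runs the base IO's save on a worker thread."""

    def __init__(self, checkpoint_io):
        self.checkpoint_io = checkpoint_io
        # A single worker runs the saves one at a time, in submission order.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures = []

    def save_checkpoint(self, checkpoint, path):
        # Shallow copy: top-level keys added or replaced by the caller later
        # do not affect the pending write (nested objects are still shared).
        snapshot = dict(checkpoint)
        future = self._executor.submit(self.checkpoint_io.save_checkpoint, snapshot, path)
        self._futures.append(future)

    def teardown(self):
        # Block until all queued saves finish, then surface any exception.
        self._executor.shutdown(wait=True)
        for future in self._futures:
            future.result()


base_io = InMemoryIO()
async_io = AsyncWrapperIO(base_io)
async_io.save_checkpoint({"epoch": 1}, "a.ckpt")  # returns without waiting for the write
async_io.save_checkpoint({"epoch": 2}, "b.ckpt")
async_io.teardown()
saved_paths = sorted(base_io.storage)
```

Because the caller returns before the write completes, anything that reads the checkpoint file should run only after the pending saves have been flushed.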