Distributed checkpoints (expert)¶
Writing your own Checkpoint class¶
Lightning provides the Checkpoint class for easier subclassing. Subclass it when you write a custom ModelCheckpoint-style callback, so that the Trainer recognizes your class as a checkpointing callback.
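A minimal sketch of such a subclass follows. The class name EveryNEpochsCheckpoint, its arguments, and the file naming scheme are hypothetical and chosen only for illustration. The sketch assumes the standard on_train_epoch_end callback hook, trainer.current_epoch, and trainer.save_checkpoint(); the import fallback exists only so the snippet can run without Lightning installed.

```python
try:
    from lightning.pytorch.callbacks import Checkpoint
except ImportError:  # fallback so the sketch runs without Lightning installed
    Checkpoint = object


class EveryNEpochsCheckpoint(Checkpoint):
    """Hypothetical callback that saves a checkpoint every `every_n_epochs` epochs."""

    def __init__(self, dirpath, every_n_epochs=1):
        super().__init__()
        self.dirpath = dirpath
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module):
        # current_epoch is zero-based, so add 1 before checking the interval
        epoch = trainer.current_epoch
        if (epoch + 1) % self.every_n_epochs == 0:
            trainer.save_checkpoint(f"{self.dirpath}/epoch-{epoch}.ckpt")
```

An instance would be passed like any other callback, for example Trainer(callbacks=[EveryNEpochsCheckpoint("ckpts", every_n_epochs=2)]). Because it inherits from Checkpoint, it should be treated as a checkpointing callback rather than a generic one.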
Customize Checkpointing¶
Warning
This is an experimental feature.
Lightning supports modifying the checkpoint save/load functionality through the CheckpointIO plugin. It encapsulates the save/load logic that is managed by the Strategy. CheckpointIO differs from the on_save_checkpoint() and on_load_checkpoint() hooks: it determines how the checkpoint is saved to and loaded from storage, rather than what is saved in the checkpoint.
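The "what" versus "how" distinction can be illustrated without Lightning at all. In the sketch below, every function name and the vocab_size entry are hypothetical: add_extra_state plays the role of an on_save_checkpoint() hook (it changes the checkpoint contents), while the two save functions play the role of two different CheckpointIO implementations (they change only the storage format).

```python
import json
import pickle
import tempfile
from pathlib import Path


def add_extra_state(checkpoint):
    """Plays the role of on_save_checkpoint(): changes WHAT is in the checkpoint."""
    checkpoint["vocab_size"] = 32000  # hypothetical extra entry
    return checkpoint


def save_as_pickle(checkpoint, path):
    """One possible HOW: a binary pickle file."""
    Path(path).write_bytes(pickle.dumps(checkpoint))


def save_as_json(checkpoint, path):
    """Another possible HOW: a human-readable JSON file."""
    Path(path).write_text(json.dumps(checkpoint))


checkpoint = add_extra_state({"epoch": 3, "global_step": 1200})

with tempfile.TemporaryDirectory() as tmp:
    pkl_path = Path(tmp) / "ckpt.pkl"
    json_path = Path(tmp) / "ckpt.json"
    save_as_pickle(checkpoint, pkl_path)
    save_as_json(checkpoint, json_path)
    # Both files hold the same contents; only the storage format differs.
    from_pickle = pickle.loads(pkl_path.read_bytes())
    from_json = json.loads(json_path.read_text())
```

Swapping the save function leaves the checkpoint contents untouched, which is the same separation of concerns that CheckpointIO provides relative to the hooks.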
Built-in Checkpoint IO Plugins¶
Plugin | Description |
---|---|
 | CheckpointIO that utilizes |
 | CheckpointIO that utilizes |
Custom Checkpoint IO Plugin¶
CheckpointIO can be extended to include your custom save/load functionality to and from a path. The CheckpointIO object can be passed either to a Trainer directly or to a Strategy, as shown below:
import torch

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.plugins import CheckpointIO
from lightning.pytorch.strategies import SingleDeviceStrategy


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        # write the checkpoint dict to `path`
        ...

    def load_checkpoint(self, path, storage_options=None):
        # read the checkpoint at `path` and return it as a dict
        ...

    def remove_checkpoint(self, path):
        # delete the checkpoint stored at `path`
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Either pass into the Trainer object
# (MyModel stands for your own LightningModule subclass)
model = MyModel()
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
Note
Some Strategy classes, such as DeepSpeedStrategy, do not support a custom CheckpointIO because their checkpointing logic is not modifiable.
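To make the three methods concrete, here is a minimal sketch of a filled-in plugin. PickleCheckpointIO is a hypothetical name, and using pickle with a temporary file is an illustrative choice, not how Lightning's built-in plugins store checkpoints. The method signatures are copied from the snippet above; check them against the CheckpointIO base class in your installed version. The import fallback exists only so the sketch runs without Lightning installed.

```python
import os
import pickle

try:
    from lightning.pytorch.plugins import CheckpointIO
except ImportError:  # fallback so the sketch runs without Lightning installed
    CheckpointIO = object


class PickleCheckpointIO(CheckpointIO):
    """Hypothetical plugin that stores checkpoints as pickle files."""

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
        # Write to a temporary file first, then rename, so an interrupted
        # save never leaves a half-written checkpoint at `path`.
        tmp_path = f"{path}.tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(checkpoint, f)
        os.replace(tmp_path, path)

    def load_checkpoint(self, path, storage_options=None):
        with open(path, "rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path):
        if os.path.exists(path):
            os.remove(path)
```

The plugin can be exercised on its own before handing it to a Trainer: save a small dict, load it back, and remove the file.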
Asynchronous Checkpointing¶
Warning
This is an experimental feature.
To save checkpoints asynchronously without blocking training, pass the AsyncCheckpointIO plugin to the Trainer.
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO

async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])
AsyncCheckpointIO uses the saving logic of its base CheckpointIO plugin but performs the save operation asynchronously. By default, this base CheckpointIO is set up for you, and all you need to provide is the AsyncCheckpointIO instance to the Trainer.
If you want the plugin to wrap your own custom base CheckpointIO so that it behaves asynchronously, pass it as an argument when initializing AsyncCheckpointIO.
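from lightning.pytorch import Trainer
from lightning.pytorch.plugins.io import AsyncCheckpointIO

# MyCustomCheckpointIO stands for your own CheckpointIO subclass
base_ckpt_io = MyCustomCheckpointIO()
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])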
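The wrap-and-delegate idea behind asynchronous saving can be sketched with the standard library alone. This is a conceptual illustration under stated assumptions, not Lightning's implementation: AsyncSaver, InMemoryIO, and the teardown method name are hypothetical, and a single-worker thread pool is one possible way to keep saves in submission order.

```python
from concurrent.futures import ThreadPoolExecutor


class InMemoryIO:
    """Hypothetical base IO that 'saves' checkpoints into a dict."""

    def __init__(self):
        self.store = {}

    def save_checkpoint(self, checkpoint, path):
        self.store[path] = dict(checkpoint)


class AsyncSaver:
    """Hypothetical wrapper that runs the base IO's save on a background thread."""

    def __init__(self, base_io):
        self.base_io = base_io
        # A single worker keeps saves in the order they were submitted.
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._futures = []

    def save_checkpoint(self, checkpoint, path):
        # Submit the save and return immediately instead of waiting for it.
        future = self._executor.submit(self.base_io.save_checkpoint, checkpoint, path)
        self._futures.append(future)

    def teardown(self):
        # Wait for pending saves, then surface any error raised in the worker.
        self._executor.shutdown(wait=True)
        for future in self._futures:
            future.result()


base_io = InMemoryIO()
saver = AsyncSaver(base_io)
saver.save_checkpoint({"epoch": 0}, "epoch-0.ckpt")
saver.save_checkpoint({"epoch": 1}, "epoch-1.ckpt")
saver.teardown()
```

One design consequence worth noting: the checkpoint object is handed to the worker by reference, so the caller should not mutate it until the save has completed.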