Checkpointing (expert)
Writing your own Checkpoint class
We provide the Checkpoint class for easier subclassing. Subclass it when you write a custom ModelCheckpoint-style callback, so that the Trainer recognizes your class as a checkpointing callback.
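As a minimal sketch of such a subclass, the callback below saves a checkpoint every few epochs. The class name `EveryNEpochsCheckpoint`, its arguments, and the file naming are hypothetical; the `try`/`except` fallback is only there so the snippet can be run without Lightning installed.

```python
import os

try:
    from pytorch_lightning.callbacks import Checkpoint
except ImportError:  # fallback so the sketch runs without Lightning installed
    Checkpoint = object


class EveryNEpochsCheckpoint(Checkpoint):
    """Hypothetical checkpointing callback: saves every `every_n_epochs` epochs."""

    def __init__(self, dirpath, every_n_epochs=1):
        super().__init__()
        self.dirpath = dirpath
        self.every_n_epochs = every_n_epochs

    def on_train_epoch_end(self, trainer, pl_module):
        epoch = trainer.current_epoch
        # Epochs are zero-based, so epoch + 1 is the number of finished epochs.
        if (epoch + 1) % self.every_n_epochs == 0:
            filepath = os.path.join(self.dirpath, f"epoch={epoch}.ckpt")
            trainer.save_checkpoint(filepath)
```

It would be passed to the Trainer like any other callback, for example `Trainer(callbacks=[EveryNEpochsCheckpoint("checkpoints", every_n_epochs=2)])`.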
Customize Checkpointing
Warning
The Checkpoint IO API is experimental and subject to change.
Lightning supports modifying the checkpoint save/load functionality through the CheckpointIO plugin, which encapsulates the save/load logic managed by the Strategy. CheckpointIO differs from the on_save_checkpoint() and on_load_checkpoint() methods in that it determines how the checkpoint is saved to and loaded from storage, rather than what is saved in the checkpoint.
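To make the distinction concrete, the sketch below shows the "what" side: hooks on a LightningModule that add an entry to the checkpoint dictionary and read it back. `LitModel` and the `vocab_size` key are hypothetical, and the `try`/`except` fallback exists only so the snippet runs without Lightning installed. How and where the resulting dictionary is written is the part left to CheckpointIO.

```python
try:
    from pytorch_lightning import LightningModule
except ImportError:  # fallback so the sketch runs without Lightning installed
    LightningModule = object


class LitModel(LightningModule):
    """Hypothetical module that stores one extra value in every checkpoint."""

    def __init__(self, vocab_size=0):
        super().__init__()
        self.vocab_size = vocab_size

    def on_save_checkpoint(self, checkpoint):
        # Decide WHAT goes into the checkpoint: add a custom entry.
        checkpoint["vocab_size"] = self.vocab_size

    def on_load_checkpoint(self, checkpoint):
        # Restore the custom entry when the checkpoint is loaded.
        self.vocab_size = checkpoint["vocab_size"]
```

Neither hook touches a file path: they only see the in-memory dictionary.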
Built-in Checkpoint IO Plugins
Plugin | Description |
---|---|
 | CheckpointIO that utilizes … |
 | CheckpointIO that utilizes … |
 | CheckpointIO to save checkpoints for HPU training strategies. |
Custom Checkpoint IO Plugin
CheckpointIO can be extended to include your custom save/load functionality to and from a path. The CheckpointIO object can be passed either to a Trainer directly or to a Strategy, as shown below:
import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy


class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        ...

    def load_checkpoint(self, path, storage_options=None):
        ...

    def remove_checkpoint(self, path):
        ...


custom_checkpoint_io = CustomCheckpointIO()

# Either pass into the Trainer object
model = MyModel()  # MyModel is your own LightningModule
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)

# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
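The method bodies above are left as `...`. As one possible way to fill them in, the sketch below pickles the checkpoint to a temporary file and then renames it over the final path, so a reader never sees a half-written file. The class name `AtomicPickleCheckpointIO` and the `.tmp` suffix are hypothetical choices, the method signatures follow the example above, and the `try`/`except` fallback is only there so the snippet runs without Lightning installed.

```python
import os
import pickle

try:
    from pytorch_lightning.plugins import CheckpointIO
except ImportError:  # fallback so the sketch runs without Lightning installed
    CheckpointIO = object


class AtomicPickleCheckpointIO(CheckpointIO):
    """Hypothetical plugin: pickles the checkpoint and writes it atomically."""

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        path = str(path)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        tmp_path = path + ".tmp"
        with open(tmp_path, "wb") as f:
            pickle.dump(checkpoint, f)
        # Rename over the final path only after the write has finished.
        os.replace(tmp_path, path)

    def load_checkpoint(self, path, storage_options=None):
        with open(path, "rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path):
        if os.path.exists(path):
            os.remove(path)
```

An instance can then be handed to the Trainer through `plugins=[...]` or to a Strategy through `checkpoint_io=...`, exactly as with `CustomCheckpointIO` above.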
Note
Some strategies, such as DeepSpeedStrategy, do not support a custom CheckpointIO because their checkpointing logic is not modifiable.
Asynchronous Checkpointing
Warning
This is currently an experimental plugin/feature and API changes are to be expected.
To save checkpoints asynchronously without blocking training, pass an AsyncCheckpointIO plugin to the Trainer.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.io import AsyncCheckpointIO

async_ckpt_io = AsyncCheckpointIO()
trainer = Trainer(plugins=[async_ckpt_io])
AsyncCheckpointIO uses the saving logic of its base CheckpointIO plugin but performs the save asynchronously. By default, this base CheckpointIO is set up for you, so all you need to provide is the AsyncCheckpointIO instance to the Trainer.
If you want the plugin to wrap your own custom CheckpointIO so that it behaves asynchronously, pass it as the checkpoint_io argument when initializing AsyncCheckpointIO.
from pytorch_lightning import Trainer
from pytorch_lightning.plugins.io import AsyncCheckpointIO

# MyCustomCheckpointIO is your own CheckpointIO subclass
base_ckpt_io = MyCustomCheckpointIO()
async_ckpt_io = AsyncCheckpointIO(checkpoint_io=base_ckpt_io)
trainer = Trainer(plugins=[async_ckpt_io])
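The wrapping idea described in this section can be pictured with the standard library alone. The sketch below is a conceptual illustration, not Lightning's implementation: `PickleIO` and `BackgroundSaver` are hypothetical names, and the single worker thread is an assumption chosen so that saves complete in the order they were requested.

```python
import os
import pickle
from concurrent.futures import ThreadPoolExecutor


class PickleIO:
    """Hypothetical base IO object with blocking save/load/remove."""

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        with open(path, "wb") as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path, storage_options=None):
        with open(path, "rb") as f:
            return pickle.load(f)

    def remove_checkpoint(self, path):
        if os.path.exists(path):
            os.remove(path)


class BackgroundSaver:
    """Delegates to a base IO object, but runs saves on a worker thread."""

    def __init__(self, checkpoint_io):
        self.checkpoint_io = checkpoint_io
        self._executor = ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def save_checkpoint(self, checkpoint, path, storage_options=None):
        # Returns immediately; the base object's save runs in the background.
        future = self._executor.submit(
            self.checkpoint_io.save_checkpoint, checkpoint, path, storage_options
        )
        self._pending.append(future)

    def load_checkpoint(self, path, storage_options=None):
        return self.checkpoint_io.load_checkpoint(path, storage_options)

    def remove_checkpoint(self, path):
        self.checkpoint_io.remove_checkpoint(path)

    def teardown(self):
        # Wait for outstanding saves and re-raise any error they hit.
        self._executor.shutdown(wait=True)
        for future in self._pending:
            future.result()
```

Only saving is deferred in this sketch; loading and removal still go straight to the base object. The caller has to invoke `teardown()` before relying on the files being complete.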