Checkpointing (advanced)¶
Cloud checkpoints¶
Lightning is integrated with the major file systems: local filesystems as well as several cloud storage providers such as S3 on AWS, GCS on Google Cloud, and ADL on Azure.
PyTorch Lightning uses fsspec internally to handle all filesystem operations.
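Each remote protocol is served by a separate fsspec backend package that must be installed alongside Lightning. The packages below are the standard fsspec implementations for each provider:

```shell
# fsspec dispatches on the URL protocol; install the backend you need
pip install s3fs    # s3://  (AWS S3)
pip install gcsfs   # gs://  (Google Cloud Storage)
pip install adlfs   # adl:// (Azure Data Lake)
```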
Save a cloud checkpoint¶
To save to a remote filesystem, prepend a protocol like "s3://" to the root_dir used for writing and reading model data.
# `default_root_dir` is the default path used for logs and checkpoints
trainer = Trainer(default_root_dir="s3://my_bucket/data/")
trainer.fit(model)
Resume training from a cloud checkpoint¶
To resume training from a cloud checkpoint, pass a cloud URL as the ckpt_path.
trainer = Trainer(default_root_dir=tmpdir, max_steps=3)
trainer.fit(model, ckpt_path="s3://my_bucket/ckpts/classifier.ckpt")
Modularize your checkpoints¶
Checkpoints can also save the state of datamodules and callbacks.
Modify a checkpoint anywhere¶
When you need to change the components of a checkpoint before saving or loading, use the on_save_checkpoint() and on_load_checkpoint() hooks of your LightningModule.
class LitModel(pl.LightningModule):
    def on_save_checkpoint(self, checkpoint):
        checkpoint["something_cool_i_want_to_save"] = my_cool_pickable_object

    def on_load_checkpoint(self, checkpoint):
        my_cool_pickable_object = checkpoint["something_cool_i_want_to_save"]
Use the above approach when you need to couple this behavior to your LightningModule for reproducibility reasons. Otherwise, Callbacks also have on_save_checkpoint() and on_load_checkpoint() hooks, which you should use instead. Note that the Callback versions also receive the trainer and pl_module as arguments:
class LitCallback(pl.Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        checkpoint["something_cool_i_want_to_save"] = my_cool_pickable_object

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        my_cool_pickable_object = checkpoint["something_cool_i_want_to_save"]