Checkpointing¶
Lightning provides functions to save and load checkpoints.
Checkpointing your training allows you to resume a training process in case it was interrupted, fine-tune a model or use a pre-trained model for inference without having to retrain the model.
Checkpoint Contents¶
A Lightning checkpoint has everything needed to restore a training session including:
16-bit scaling factor (if using 16-bit precision training)
Current epoch
Global step
LightningModule’s state_dict
State of all optimizers
State of all learning rate schedulers
State of all callbacks (for stateful callbacks)
State of datamodule (for stateful datamodules)
The hyperparameters used for that model if passed in as hparams (argparse.Namespace)
State of Loops (if using Fault-Tolerant training)
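A checkpoint is a plain dictionary, so you can inspect its contents directly. A minimal sketch (the path and the keys shown are illustrative):
import torch

# load the raw checkpoint dictionary and list its top-level keys
checkpoint = torch.load("path/to/checkpoint.ckpt")
print(checkpoint.keys())
# e.g. dict_keys(['epoch', 'global_step', 'state_dict', 'optimizer_states', 'lr_schedulers', ...])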
Individual Component States¶
Each component can save and load its state by implementing the PyTorch state_dict, load_state_dict stateful protocol.
For details on implementing your own stateful callbacks and datamodules, refer to the individual docs pages at callbacks and datamodules.
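For example, a minimal sketch of a stateful callback, assuming a recent Lightning version where Callback exposes the state_dict/load_state_dict hooks, and using a hypothetical counter as the persisted state:
from pytorch_lightning.callbacks import Callback


class CounterCallback(Callback):
    def __init__(self):
        self.batches_seen = 0  # hypothetical state we want checkpointed

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        # this dict is stored in the checkpoint under the callback's entry
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # called when the checkpoint is loaded so training can resume the count
        self.batches_seen = state_dict["batches_seen"]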
Operating on Global Checkpoint Component States¶
If you need to operate on the global component state (i.e. the entire checkpoint dictionary), you can read/add/delete/modify custom states in your checkpoints before they are being saved or loaded.
For this, you can override on_save_checkpoint() and on_load_checkpoint() in your LightningModule, or the on_save_checkpoint() and on_load_checkpoint() methods in your Callback.
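As a sketch, adding a custom (hypothetical) entry to the checkpoint dictionary from a LightningModule could look like this:
from pytorch_lightning import LightningModule


class MyLightningModule(LightningModule):
    def on_save_checkpoint(self, checkpoint):
        # add a custom entry before the checkpoint is written to disk
        checkpoint["my_custom_state"] = {"something": "important"}

    def on_load_checkpoint(self, checkpoint):
        # read the custom entry back when the checkpoint is loaded
        my_custom_state = checkpoint["my_custom_state"]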
Checkpoint Saving¶
Automatic Saving¶
Lightning automatically saves a checkpoint for you in your current working directory, with the state of your last training epoch. This makes sure you can resume training in case it was interrupted.
To change the checkpoint path, pass in:
# saves checkpoints to '/your/path/to/save/checkpoints' at every epoch end
trainer = Trainer(default_root_dir="/your/path/to/save/checkpoints")
You can retrieve the checkpoint after training by calling:
checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_top_k=2, monitor="val_loss")
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)
checkpoint_callback.best_model_path  # path to the best checkpoint after training
Disabling Checkpoints¶
You can disable checkpointing by passing:
trainer = Trainer(enable_checkpointing=False)
Manual Saving¶
You can manually save checkpoints and restore your model from the checkpointed state using save_checkpoint() and load_from_checkpoint().
model = MyLightningModule(hparams)
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")
Manual Saving with Distributed Training Strategies¶
Lightning also handles strategies where multiple processes are running, such as DDP. For example, when using the DDP strategy your training script runs across multiple devices at the same time. Lightning automatically ensures that the model is saved only on the main process, while other processes do not interfere with saving checkpoints. This requires no code changes, as shown below:
trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)
# Saves only on the main process
trainer.save_checkpoint("example.ckpt")
Not using save_checkpoint() can lead to unexpected behavior and potential deadlock. Using other saving functions will result in all devices attempting to save the checkpoint. As a result, we highly recommend using the Trainer's save functionality. If using custom saving functions cannot be avoided, we recommend using the rank_zero_only() decorator to ensure saving occurs only on the main process. Note that this will only work if all ranks hold the exact same state, and won't work when using model parallel distributed strategies such as deepspeed or sharded training.
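For example, a minimal sketch of a custom save function guarded by the decorator, assuming rank_zero_only is importable from pytorch_lightning.utilities in your version:
import torch
from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def save_custom_state(model, path):
    # runs only on the main process; all other ranks skip the call
    torch.save(model.state_dict(), path)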
Checkpointing Hyperparameters¶
The Lightning checkpoint also saves the arguments passed into the LightningModule init under the "hyper_parameters" key in the checkpoint.
class MyLightningModule(LightningModule):
    def __init__(self, learning_rate, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()


# all init args were saved to the checkpoint
checkpoint = torch.load(CKPT_PATH)
print(checkpoint["hyper_parameters"])
# {"learning_rate": the_value}
Checkpoint Loading¶
To load a model along with its weights and hyperparameters use the following method:
model = MyLightningModule.load_from_checkpoint(PATH)
print(model.learning_rate)
# prints the learning_rate you used in this checkpoint
model.eval()
y_hat = model(x)
But if you don’t want to use the hyperparameters saved in the checkpoint, pass in your own here:
class LitModel(LightningModule):
    def __init__(self, in_dim, out_dim):
        super().__init__()
        self.save_hyperparameters()
        self.l1 = nn.Linear(self.hparams.in_dim, self.hparams.out_dim)
You can restore the model like this:
# if you train and save the model like this it will use these values when loading
# the weights. But you can overwrite this
LitModel(in_dim=32, out_dim=10)
# uses in_dim=32, out_dim=10
model = LitModel.load_from_checkpoint(PATH)
# uses in_dim=128, out_dim=10
model = LitModel.load_from_checkpoint(PATH, in_dim=128, out_dim=10)
Restoring Training State¶
If you don’t just want to load weights, but instead restore the full training, do the following:
model = LitModel()
trainer = Trainer()
# automatically restores model, epoch, step, LR schedulers, apex, etc...
trainer.fit(model, ckpt_path="some/path/to/my_checkpoint.ckpt")
Conditional Checkpointing (ModelCheckpoint)¶
The ModelCheckpoint callback allows you to configure when/which/what/where checkpointing should happen. It follows the normal Callback hook structure, so you can also override its methods for your own use-cases. Following are some of the common use-cases along with the arguments you need to specify to configure it:
How does it work?¶
ModelCheckpoint helps cover the following cases from the WH-family (when, which, what, and where):
When¶
When using iterative training, which doesn't have an epoch, you can checkpoint every N training steps by specifying every_n_train_steps=N. You can also control the interval of epochs between checkpoints using every_n_epochs, to avoid slowdowns. You can checkpoint at a regular time interval using the train_time_interval argument, independent of steps or epochs. In case you are monitoring a training metric, we'd suggest using save_on_train_epoch_end=True to ensure the required metric is being accumulated correctly for creating a checkpoint.
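For example, a sketch that checkpoints every 1000 training steps (the interval and path are illustrative):
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# checkpoint every 1000 training steps, regardless of epoch boundaries
checkpoint_callback = ModelCheckpoint(dirpath="my/path/", every_n_train_steps=1000)
trainer = Trainer(callbacks=[checkpoint_callback])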
Which¶
You can save the last checkpoint when training ends using the save_last argument. You can save top-K and last-K checkpoints by configuring the monitor and save_top_k arguments.
from pytorch_lightning.callbacks import ModelCheckpoint

# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)

# saves last-K checkpoints based on "global_step" metric
# make sure you log it inside your LightningModule
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="global_step",
    mode="max",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{global_step}",
)
You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:
from pytorch_lightning.callbacks import ModelCheckpoint


class LitAutoEncoder(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)

        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)

        # 2. log val_loss
        self.log("val_loss", loss)


# 3. Init ModelCheckpoint callback, monitoring "val_loss"
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# 4. Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
What¶
By default, the ModelCheckpoint callback saves model weights, optimizer states, etc., but in case you have limited disk space or just need the model weights to be saved, you can specify save_weights_only=True.
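A minimal sketch (the path is illustrative):
from pytorch_lightning.callbacks import ModelCheckpoint

# store only the model weights, skipping optimizer and scheduler states
checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_weights_only=True)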
Where¶
It gives you the ability to specify the dirpath and filename for your checkpoints. The filename can also be dynamic, so you can inject the metrics that are being logged using log().
from pytorch_lightning.callbacks import ModelCheckpoint

# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)
The ModelCheckpoint callback is very robust and should cover 99% of use-cases. If you find a use-case that is not configured yet, feel free to open an issue with a feature request on GitHub, and the Lightning Team will be happy to integrate it or help you integrate it.
Customize Checkpointing¶
Warning
The Checkpoint IO API is experimental and subject to change.
Lightning supports modifying the checkpointing save/load functionality through the CheckpointIO plugin. This encapsulates the save/load logic that is managed by the Strategy. CheckpointIO is different from the on_save_checkpoint() and on_load_checkpoint() methods, as it determines how the checkpoint is saved/loaded to storage rather than what's saved in the checkpoint.
Built-in Checkpoint IO Plugins¶
Plugin | Description
---|---
TorchCheckpointIO | CheckpointIO that utilizes torch.save and torch.load to save and load checkpoints respectively, common for most use cases.
XLACheckpointIO | CheckpointIO that utilizes xm.save to save checkpoints for TPU training strategies.
Custom Checkpoint IO Plugin¶
CheckpointIO can be extended to include your custom save/load functionality to and from a path. The CheckpointIO object can be passed to either a Trainer directly or a Strategy as shown below:
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.strategies import SingleDeviceStrategy
class CustomCheckpointIO(CheckpointIO):
    def save_checkpoint(self, checkpoint, path, storage_options=None):
        ...

    def load_checkpoint(self, path, storage_options=None):
        ...

    def remove_checkpoint(self, path):
        ...
custom_checkpoint_io = CustomCheckpointIO()
# Either pass into the Trainer object
model = MyModel()
trainer = Trainer(
    plugins=[custom_checkpoint_io],
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
# or pass into Strategy
model = MyModel()
device = torch.device("cpu")
trainer = Trainer(
    strategy=SingleDeviceStrategy(device, checkpoint_io=custom_checkpoint_io),
    callbacks=ModelCheckpoint(save_last=True),
)
trainer.fit(model)
Note
Some strategies like DeepSpeedStrategy do not support custom CheckpointIO, as their checkpointing logic is not modifiable.
Managing Remote Filesystems¶
Lightning supports saving and loading checkpoints from a variety of filesystems, including local filesystems and several cloud storage providers.
Check out the Remote Filesystems document for more info.
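For example, a sketch that writes checkpoints directly to an S3 bucket (the bucket name is hypothetical and assumes the corresponding fsspec/s3fs dependency is installed):
from pytorch_lightning import Trainer

# checkpoints and logs are written to the remote path via fsspec
trainer = Trainer(default_root_dir="s3://my-bucket/checkpoints/")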