Customize checkpointing behavior (intermediate)
Audience: Users looking to customize the checkpointing behavior
Modify checkpointing behavior
For fine-grained control over checkpointing behavior, use the ModelCheckpoint object:
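from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_top_k=2, monitor="val_loss")
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)  # model: your LightningModule, which must log "val_loss"
# path to the best checkpoint according to the monitored metric
checkpoint_callback.best_model_path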
Any value that has been logged via self.log in the LightningModule can be monitored.
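import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint

class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        self.log("my_metric", x)  # x: any scalar computed in this step, e.g. the loss

# 'my_metric' can now be monitored
checkpoint_callback = ModelCheckpoint(monitor="my_metric")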
Save checkpoints by condition
To save checkpoints based on a when/which/what/where condition (for example, when the validation loss decreases), modify the ModelCheckpoint properties.
When
When using iteration-based training that doesn’t have epochs, you can checkpoint every N training steps by specifying every_n_train_steps=N.
You can also control the interval of epochs between checkpoints using every_n_epochs, to avoid slowdowns.
You can checkpoint at a regular time interval, independent of steps or epochs, using the train_time_interval argument.
If you are monitoring a training metric, we suggest using save_on_train_epoch_end=True to ensure the metric has been accumulated correctly before a checkpoint is created.
Which
You can save the last checkpoint when training ends using the save_last argument.
You can save the top-K and last-K checkpoints by configuring the monitor and save_top_k arguments.
from lightning.pytorch.callbacks import ModelCheckpoint

# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)

# saves last-K checkpoints based on "global_step" metric
# make sure you log it inside your LightningModule
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="global_step",
    mode="max",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{global_step}",
)
You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:
import torch.nn.functional as F
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

class LitAutoEncoder(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)
        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)
        # 2. log val_loss
        self.log("val_loss", loss)

# 3. Init ModelCheckpoint callback, monitoring "val_loss"
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# 4. Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])
What
By default, the ModelCheckpoint callback saves model weights, optimizer states, etc. If you have limited disk space or only need the model weights, you can specify save_weights_only=True.
Where
By default, the ModelCheckpoint saves files into Trainer.log_dir. You can also specify the dirpath and filename for your checkpoints. The filename can be dynamic, so you can inject metrics that are being logged using log().
from lightning.pytorch.callbacks import ModelCheckpoint

# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)
The ModelCheckpoint callback is flexible and should cover the vast majority of use cases. If you find a use case that it does not support yet, feel free to open a feature request on GitHub, and the Lightning team will be happy to help integrate it.
Save checkpoints manually
You can manually save checkpoints and restore your model from the checkpointed state using save_checkpoint() and load_from_checkpoint().
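from lightning.pytorch import Trainer

# MyLightningModule and hparams stand for your own LightningModule and its arguments
model = MyLightningModule(hparams)
trainer = Trainer()
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

# load the checkpoint later as normal
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")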
Manual saving with distributed training
In distributed training, where a model runs across many processes or machines, Lightning ensures that only one checkpoint is saved instead of one per machine. This requires no code changes, as seen below:
from lightning.pytorch import Trainer

trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)
# Saves only on the main process
# Handles strategy-specific saving logic like XLA, FSDP, DeepSpeed etc.
trainer.save_checkpoint("example.ckpt")
By using save_checkpoint() instead of torch.save, you make your code agnostic to the distributed training strategy being used.
It ensures that checkpoints are saved correctly in a multi-process setting, avoiding race conditions, deadlocks, and other common issues that normally require boilerplate code to handle properly.