Customize checkpointing behavior (intermediate)

Audience: Users looking to customize the checkpointing behavior


Modify checkpointing behavior

For fine-grained control over checkpointing behavior, use the ModelCheckpoint callback:

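from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_top_k=2, monitor="val_loss")
trainer = Trainer(callbacks=[checkpoint_callback])
trainer.fit(model)  # `model` is your LightningModule

# path to the best checkpoint according to the monitored "val_loss"
checkpoint_callback.best_model_path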

Any value that has been logged via self.log in the LightningModule can be monitored.

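import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint


class LitModel(L.LightningModule):
    def training_step(self, batch, batch_idx):
        self.log("my_metric", x)  # x: any scalar (tensor or number) computed in this step


# 'my_metric' can now be monitored
checkpoint_callback = ModelCheckpoint(monitor="my_metric")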

Save checkpoints by condition

To save checkpoints based on a condition, such as when, which, what, or where to save (for example, only when the validation loss decreases), configure the corresponding ModelCheckpoint arguments.

When

  • For iteration-based training that has no notion of epochs, you can checkpoint every N training steps by setting every_n_train_steps=N.

  • You can also set the number of epochs between checkpoints with every_n_epochs, which avoids the slowdown of saving after every epoch.

  • You can checkpoint at a regular wall-clock interval, independent of steps or epochs, by passing a datetime.timedelta to the train_time_interval argument.

  • If you are monitoring a training metric, we suggest setting save_on_train_epoch_end=True so that the metric has been fully accumulated over the epoch before a checkpoint is created.

Which

  • Set save_last=True to also keep a copy of the most recent checkpoint as last.ckpt, which is convenient for resuming training.

  • You can save the top-K or last-K checkpoints by configuring the monitor and save_top_k arguments.


from lightning.pytorch.callbacks import ModelCheckpoint


# saves top-K checkpoints based on "val_loss" metric
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="val_loss",
    mode="min",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)

# saves last-K checkpoints based on "global_step" metric
# make sure you log it inside your LightningModule
checkpoint_callback = ModelCheckpoint(
    save_top_k=10,
    monitor="global_step",
    mode="max",
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{global_step}",
)

  • You can customize the checkpointing behavior to monitor any quantity logged in your training or validation steps. For example, to update your checkpoints based on the validation loss:


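import torch.nn.functional as F
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import ModelCheckpoint


class LitAutoEncoder(LightningModule):
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.backbone(x)  # assumes a `backbone` module defined in __init__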

        # 1. calculate loss
        loss = F.cross_entropy(y_hat, y)

        # 2. log val_loss
        self.log("val_loss", loss)


# 3. Init ModelCheckpoint callback, monitoring "val_loss"
checkpoint_callback = ModelCheckpoint(monitor="val_loss")

# 4. Add your callback to the callbacks list
trainer = Trainer(callbacks=[checkpoint_callback])

What

  • By default, the ModelCheckpoint callback saves the model weights together with the optimizer and learning-rate scheduler states and the rest of the training state. If disk space is limited or you only need the model weights, set save_weights_only=True.

Where

  • By default, ModelCheckpoint saves files to a checkpoints subdirectory of Trainer.log_dir. You can set the dirpath and filename for your checkpoints yourself, and the filename can be dynamic: any metric logged with log() can be injected into it.


from lightning.pytorch.callbacks import ModelCheckpoint


# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
checkpoint_callback = ModelCheckpoint(
    dirpath="my/path/",
    filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
)

The ModelCheckpoint callback is flexible enough to cover the vast majority of use cases. If you find one that is not supported yet, feel free to open a feature request on GitHub; the Lightning team will be happy to integrate it or help you integrate it.


Save checkpoints manually

You can manually save checkpoints and restore your model from the checkpointed state using save_checkpoint() and load_from_checkpoint().

from lightning.pytorch import Trainer

model = MyLightningModule(hparams)  # your LightningModule and its hyperparameters
trainer = Trainer()
trainer.fit(model)
trainer.save_checkpoint("example.ckpt")

# load the checkpoint later as normal
new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")

Manual saving with distributed training

In distributed training, where a model runs across multiple processes or machines, Lightning ensures that a single checkpoint is saved rather than one copy per process. This requires no code changes:

from lightning.pytorch import Trainer

trainer = Trainer(strategy="ddp")
model = MyLightningModule(hparams)
trainer.fit(model)

# Saves only on the main process
# Handles strategy-specific saving logic like XLA, FSDP, DeepSpeed etc.
trainer.save_checkpoint("example.ckpt")

By using save_checkpoint() instead of torch.save, you make your code agnostic to the distributed training strategy being used. It will ensure that checkpoints are saved correctly in a multi-process setting, avoiding race conditions, deadlocks and other common issues that normally require boilerplate code to handle properly.
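For comparison, the sketch below shows roughly the kind of boilerplate that save_checkpoint() spares you when saving with plain torch.save under torch.distributed: only rank 0 writes, and a barrier keeps the other ranks from racing ahead. The helper save_on_rank_zero is hypothetical and is not how Lightning implements saving; in particular, it does not handle sharded strategies such as FSDP or DeepSpeed, where each process holds only part of the state. It also runs unchanged in a single, non-distributed process.

```python
import os
import tempfile

import torch
import torch.distributed as dist


def save_on_rank_zero(state: dict, path: str) -> bool:
    """Hypothetical helper: save `state` from rank 0 only; return True if this process wrote it."""
    distributed = dist.is_available() and dist.is_initialized()
    is_main = (not distributed) or dist.get_rank() == 0
    if is_main:
        torch.save(state, path)  # a single writer avoids processes clobbering the same file
    if distributed:
        dist.barrier()  # all ranks wait here until rank 0 has finished writing
    return is_main


# single-process usage: behaves like a plain torch.save
with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "example.ckpt")
    wrote = save_on_rank_zero({"weight": torch.ones(2)}, ckpt_path)
    loaded = torch.load(ckpt_path)
print(wrote, loaded["weight"])
```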