:orphan:

.. _checkpointing_intermediate_1:

###############################################
Customize checkpointing behavior (intermediate)
###############################################
**Audience:** Users looking to customize the checkpointing behavior

----

*****************************
Modify checkpointing behavior
*****************************
For fine-grained control over checkpointing behavior, use the :class:`~lightning.pytorch.callbacks.ModelCheckpoint` object:

.. code-block:: python

    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint

    checkpoint_callback = ModelCheckpoint(dirpath="my/path/", save_top_k=2, monitor="val_loss")
    trainer = Trainer(callbacks=[checkpoint_callback])
    trainer.fit(model)  # model is your LightningModule instance

    # path of the best checkpoint according to "val_loss"
    checkpoint_callback.best_model_path

Any value that has been logged via ``self.log`` in the LightningModule can be monitored.

.. code-block:: python

    import lightning as L
    from lightning.pytorch.callbacks import ModelCheckpoint


    class LitModel(L.LightningModule):
        def training_step(self, batch, batch_idx):
            x = ...  # any scalar (number or tensor) computed in this step
            self.log("my_metric", x)


    # 'my_metric' can now be monitored
    checkpoint_callback = ModelCheckpoint(monitor="my_metric")

----

*****************************
Save checkpoints by condition
*****************************
To save checkpoints based on a (*when/which/what/where*) condition (for example, *when* the validation loss is lower), modify the :class:`~lightning.pytorch.callbacks.ModelCheckpoint` properties.

When
====

- When using iteration-based training that has no notion of an epoch, you can checkpoint every ``N`` training steps by specifying ``every_n_train_steps=N``.
- You can also control the number of epochs between checkpoints with ``every_n_epochs``; checkpointing less often avoids the slowdown of saving after every epoch.
- You can checkpoint at a regular time interval with the ``train_time_interval`` argument, independent of steps or epochs.
- If you are monitoring a training metric, we suggest setting ``save_on_train_epoch_end=True`` so that the metric is fully accumulated before a checkpoint is created.
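
A minimal sketch of the triggers above, with one callback per trigger. The directory names, the interval values, and the ``train_loss`` metric name are illustrative assumptions. To our understanding, ``ModelCheckpoint`` rejects setting more than one of ``every_n_train_steps``, ``every_n_epochs``, and ``train_time_interval`` on the same instance, which is why each trigger gets its own callback here.

.. code-block:: python

    from datetime import timedelta

    from lightning.pytorch.callbacks import ModelCheckpoint

    # Iteration-based training: save a checkpoint every 1000 training steps.
    step_checkpoint = ModelCheckpoint(dirpath="checkpoints/steps/", every_n_train_steps=1000)

    # Epoch-based training: save only every 5 epochs to reduce checkpointing overhead.
    epoch_checkpoint = ModelCheckpoint(dirpath="checkpoints/epochs/", every_n_epochs=5)

    # Wall-clock based: save every 30 minutes, independent of steps or epochs.
    time_checkpoint = ModelCheckpoint(
        dirpath="checkpoints/time/",
        train_time_interval=timedelta(minutes=30),
    )

    # Monitoring a training metric: evaluate it at the end of the training epoch.
    train_metric_checkpoint = ModelCheckpoint(
        dirpath="checkpoints/train_metric/",
        monitor="train_loss",  # hypothetical name; log it via self.log("train_loss", ...)
        mode="min",
        save_on_train_epoch_end=True,
    )

    callbacks = [step_checkpoint, epoch_checkpoint, time_checkpoint, train_metric_checkpoint]

You would then pass the list to the Trainer, for example ``Trainer(callbacks=callbacks)``. The ``train_loss`` callback only works if your ``training_step`` actually logs a metric with that name.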

Which
=====

- You can save the last checkpoint when training ends using the ``save_last`` argument.
- You can save top-K and last-K checkpoints by configuring the ``monitor`` and ``save_top_k`` arguments.

|

.. testcode::

    from lightning.pytorch.callbacks import ModelCheckpoint

    # saves top-K checkpoints based on "val_loss" metric
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="val_loss",
        mode="min",
        dirpath="my/path/",
        filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
    )

    # saves last-K checkpoints based on "global_step" metric
    # make sure you log it inside your LightningModule, e.g. self.log("global_step", self.global_step)
    checkpoint_callback = ModelCheckpoint(
        save_top_k=10,
        monitor="global_step",
        mode="max",
        dirpath="my/path/",
        filename="sample-mnist-{epoch:02d}-{global_step}",
    )

- You can customize the checkpointing behavior to monitor any quantity of your training or validation steps. For example, if you want to update your checkpoints based on your validation loss:

|

.. testcode::

    import torch.nn.functional as F
    from lightning.pytorch import LightningModule, Trainer
    from lightning.pytorch.callbacks import ModelCheckpoint


    class LitAutoEncoder(LightningModule):
        def validation_step(self, batch, batch_idx):
            x, y = batch
            y_hat = self.backbone(x)

            # 1. calculate loss
            loss = F.cross_entropy(y_hat, y)

            # 2. log val_loss
            self.log("val_loss", loss)


    # 3. Init ModelCheckpoint callback, monitoring "val_loss"
    checkpoint_callback = ModelCheckpoint(monitor="val_loss")

    # 4. Add your callback to the callbacks list
    trainer = Trainer(callbacks=[checkpoint_callback])

What
====

- By default, the ``ModelCheckpoint`` callback saves model weights, optimizer states, etc. If you have limited disk space or only need the model weights, specify ``save_weights_only=True``.

Where
=====

- By default, ``ModelCheckpoint`` saves files into ``Trainer.log_dir``. You can specify the ``dirpath`` and ``filename`` for your checkpoints.
- The filename can also be dynamic, so you can inject the metrics that are being logged using :meth:`~lightning.pytorch.core.LightningModule.log`.

|

.. testcode::

    from lightning.pytorch.callbacks import ModelCheckpoint

    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    checkpoint_callback = ModelCheckpoint(
        dirpath="my/path/",
        filename="sample-mnist-{epoch:02d}-{val_loss:.2f}",
    )

|

The :class:`~lightning.pytorch.callbacks.ModelCheckpoint` callback is very robust and should cover most use cases. If you find a use case that is not supported yet, feel free to open a feature request on GitHub, and the Lightning team will be happy to integrate it or help you integrate it.

----

*************************
Save checkpoints manually
*************************

You can manually save checkpoints and restore your model from the checkpointed state using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` and :meth:`~lightning.pytorch.core.LightningModule.load_from_checkpoint`.

.. code-block:: python

    model = MyLightningModule(hparams)
    trainer.fit(model)
    trainer.save_checkpoint("example.ckpt")

    # load the checkpoint later as normal
    new_model = MyLightningModule.load_from_checkpoint(checkpoint_path="example.ckpt")

Manual saving with distributed training
=======================================

In distributed training, where a model runs across many processes or machines, Lightning ensures that only one checkpoint is saved instead of one per process. This requires no code changes:

.. code-block:: python

    trainer = Trainer(strategy="ddp")
    model = MyLightningModule(hparams)
    trainer.fit(model)

    # Saves only on the main process
    trainer.save_checkpoint("example.ckpt")

Not using :meth:`~lightning.pytorch.trainer.trainer.Trainer.save_checkpoint` can lead to unexpected behavior and potential deadlocks. Other saving functions will cause every device to attempt to save the checkpoint.
As a result, we highly recommend using the Trainer's save functionality. If using custom saving functions cannot be avoided, we recommend using the :func:`~lightning.pytorch.utilities.rank_zero.rank_zero_only` decorator to ensure saving occurs only on the main process. Note that this only works if all ranks hold exactly the same state, and it won't work with model-parallel distributed strategies such as DeepSpeed or sharded training.
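
As an illustration, here is a minimal sketch of a custom saving function guarded by ``rank_zero_only``. It assumes that a plain ``torch.save`` of the model's ``state_dict`` is what you need; the ``save_weights`` function and the file name are hypothetical. In a DDP run every process would call ``save_weights``, but only global rank 0 writes the file, which is why the caveat above about identical state across ranks applies.

.. code-block:: python

    import os
    import tempfile

    import torch
    from lightning.pytorch.utilities.rank_zero import rank_zero_only


    @rank_zero_only
    def save_weights(model: torch.nn.Module, path: str) -> None:
        """Hypothetical custom save: runs on global rank 0 only, a no-op on other ranks."""
        torch.save(model.state_dict(), path)


    # Outside a distributed run, the single process counts as rank 0, so the file is written.
    model = torch.nn.Linear(4, 2)
    out_dir = tempfile.mkdtemp()
    weights_path = os.path.join(out_dir, "custom_weights.pt")
    save_weights(model, weights_path)

Avoid calling collective operations (for example ``all_gather``) inside a ``rank_zero_only`` function: only rank 0 would enter the collective while the other ranks skip it, which hangs the job.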