Save the model periodically by monitoring a quantity.

Model Checkpointing

Automatically save model checkpoints during training.

class pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, every_n_val_epochs=None)

Bases: pytorch_lightning.callbacks.base.Callback

Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() in LightningModule is a candidate for the monitor key. For more information, see Saving and loading weights.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.
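The bookkeeping behind best_model_path and best_model_score can be pictured with a small pure-Python sketch. TopKTracker and its update method are hypothetical names used only for illustration. This is not Lightning's implementation, only an approximation of how save_top_k and mode might decide which checkpoints are kept.

```python
from typing import Dict, Optional


class TopKTracker:
    """Hypothetical helper (not part of pytorch_lightning) that keeps the best k scores."""

    def __init__(self, k: int = 1, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.k = k
        self.mode = mode
        self.best_k: Dict[str, float] = {}  # checkpoint path -> monitored score

    def update(self, path: str, score: float) -> bool:
        """Record a candidate checkpoint; return True if it is kept."""
        if self.k == 0:  # k == 0: nothing is kept
            return False
        if self.k == -1 or len(self.best_k) < self.k:  # k == -1: keep everything
            self.best_k[path] = score
            return True
        # The worst kept entry is the largest score for 'min', the smallest for 'max'.
        pick_worst = max if self.mode == "min" else min
        worst = pick_worst(self.best_k, key=self.best_k.get)
        improved = score < self.best_k[worst] if self.mode == "min" else score > self.best_k[worst]
        if not improved:
            return False
        del self.best_k[worst]
        self.best_k[path] = score
        return True

    @property
    def best_model_path(self) -> Optional[str]:
        if not self.best_k:
            return None
        pick_best = min if self.mode == "min" else max
        return pick_best(self.best_k, key=self.best_k.get)

    @property
    def best_model_score(self) -> Optional[float]:
        path = self.best_model_path
        return None if path is None else self.best_k[path]


tracker = TopKTracker(k=2, mode="min")
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.4]):
    tracker.update(f"epoch={epoch}.ckpt", val_loss)

print(tracker.best_model_path, tracker.best_model_score)
```

With mode="max" the same tracker would keep the largest scores instead, which matches the guidance to use 'max' for accuracy-like metrics and 'min' for loss-like ones.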

  • dirpath (Union[str, Path, None]) –

    directory to save the model file.


    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

    By default, dirpath is None and will be set at runtime to the location specified by the Trainer’s default_root_dir or weights_save_path arguments. If the Trainer uses a logger, the path will also contain the logger name and version.

  • filename (Optional[str]) –

    checkpoint filename. Can contain named formatting options to be auto-filled.


    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )

    By default, filename is None and will be set to '{epoch}-{step}'.

  • monitor (Optional[str]) – quantity to monitor. By default it is None which saves a checkpoint only for the last epoch.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.

  • save_top_k (int) – If save_top_k == k, the best k models according to the monitored quantity are saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitors are checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, a version count starting with v1 is appended to the name of the saved file.

  • mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should be 'min', etc.

  • auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02d}' with epoch 1 and acc 80 will resolve to checkpoint_epoch=01-acc=80.ckpt. It is useful to set this to False when metric names contain /, as the slash would otherwise result in extra folders.

  • save_weights_only (bool) – if True, only the model’s weights are saved (model.save_weights(filepath)); otherwise the full model is saved.

  • every_n_train_steps (Optional[int]) – Number of training steps between checkpoints. If every_n_train_steps == None or every_n_train_steps == 0, saving during training is skipped. To disable, set every_n_train_steps = 0. This value must be None or non-negative. This argument is mutually exclusive with train_time_interval and every_n_epochs.

  • train_time_interval (Optional[timedelta]) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. The check is not guaranteed to execute at the exact time specified, but should be close. This argument is mutually exclusive with every_n_train_steps and every_n_epochs.

  • every_n_epochs (Optional[int]) – Number of epochs between checkpoints. If every_n_epochs == None or every_n_epochs == 0, saving at the end of the epoch is skipped. To disable, set every_n_epochs = 0. This value must be None or non-negative. This argument is mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both every_n_epochs and check_val_every_n_epoch evenly divide E.

  • save_on_train_epoch_end (Optional[bool]) – Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of the validation.

  • every_n_val_epochs (Optional[int]) –

    Number of epochs between checkpoints.


    This argument has been deprecated in v1.4 and will be removed in v1.6.

    Use every_n_epochs instead.
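The interaction between every_n_epochs and the Trainer's check_val_every_n_epoch described above is plain divisibility arithmetic. The sketch below only restates that rule; checkpoint_epochs is a hypothetical helper, not a Lightning function, and it counts epochs E from 1 as the description does.

```python
from typing import List


def checkpoint_epochs(max_epochs: int, every_n_epochs: int, check_val_every_n_epoch: int) -> List[int]:
    """Epochs E with 0 < E <= max_epochs that both intervals divide evenly.

    Hypothetical helper: it mirrors the rule stated for every_n_epochs when
    save_on_train_epoch_end=False and is not part of pytorch_lightning.
    """
    # An interval of 0 disables saving; this also guards against division by zero.
    if every_n_epochs < 1 or check_val_every_n_epoch < 1:
        return []
    return [
        e
        for e in range(1, max_epochs + 1)
        if e % every_n_epochs == 0 and e % check_val_every_n_epoch == 0
    ]


saved = checkpoint_epochs(max_epochs=12, every_n_epochs=2, check_val_every_n_epoch=3)
print(saved)  # [6, 12]
```

In this example validation runs every 3 epochs and checkpointing is requested every 2, so a checkpoint is only written at multiples of 6.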


For extra customization, ModelCheckpoint includes the following attributes:



  • FILE_EXTENSION = ".ckpt"


For example, you can change the default last checkpoint name by doing checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"

If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.

  • MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min" or "max".

  • ValueError – If trainer.save_checkpoint is None.


>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path


Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments:

monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end

Read more: Persisting State

file_exists(filepath, trainer)

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type

bool

format_checkpoint_name(metrics, filename=None, ver=None)

Generate a filename according to the defined template.
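As a rough illustration of how such a template might expand, here is a minimal pure-Python sketch. The function format_name is hypothetical and is not Lightning's code. It assumes that auto_insert_metric_name rewrites each {name...} field to name={name...}, that metrics missing from the dictionary are filled with 0, and that the ".ckpt" extension is appended.

```python
import re
from typing import Dict, Union

Number = Union[int, float]

# Matches the name part of a format field such as "{epoch" or "{val_loss".
_FIELD = re.compile(r"\{([^{}:]+)")


def format_name(
    template: str,
    metrics: Dict[str, Number],
    auto_insert_metric_name: bool = True,
    extension: str = ".ckpt",
) -> str:
    """Hypothetical sketch of template expansion; not the library implementation."""
    values = dict(metrics)
    # Assumption: fields that were never logged are formatted as 0.
    for name in _FIELD.findall(template):
        values.setdefault(name, 0)
    if auto_insert_metric_name:
        # "{epoch:03d}" becomes "epoch={epoch:03d}"
        template = _FIELD.sub(r"\1={\1", template)
    return template.format(**values) + extension


print(format_name("{epoch}-{val_loss:.2f}", {"epoch": 2, "val_loss": 0.123456}))
print(format_name("{epoch:03d}", {"epoch": 5}))
```

Passing auto_insert_metric_name=False leaves the template untouched, so any prefix such as "epoch=" has to be written into the template by hand.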


>>> import os
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))

Return type

str



Called when the trainer initialization ends, model has not yet been set.

Return type


on_load_checkpoint(trainer, pl_module, callback_state)

Called when loading a model checkpoint; use it to reload state.

Return type



on_load_checkpoint is not called with an undefined state. Even if your on_load_checkpoint hook does not rely on a state, you still need to override on_save_checkpoint to return a dummy state.

on_pretrain_routine_start(trainer, pl_module)

When the pretrain routine starts, the checkpoint directory is built on the fly.

Return type


on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a model checkpoint; use it to persist state.

Return type

Dict[str, Any]


The callback state.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Save a checkpoint at the end of a training batch if the every_n_train_steps criterion is met.

Return type


on_train_end(trainer, pl_module)

Save a checkpoint when training stops.

This will only save a checkpoint if save_last is also enabled, because the monitor metrics logged during training/validation steps or at the end of epochs are not guaranteed to be available at this stage.

Return type


on_train_epoch_end(trainer, pl_module)

Save a checkpoint at the end of the training epoch.

Return type


on_train_start(trainer, pl_module)

Called when training begins.

Return type


on_validation_end(trainer, pl_module)

Save a checkpoint at the end of the validation stage.

Return type



Performs the main logic around saving a checkpoint.

This method runs on all ranks. It is the responsibility of trainer.save_checkpoint to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.

Return type



Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.

Return type


property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
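The idea of a unique state key can be sketched in plain Python. Everything below is hypothetical: SketchCheckpoint, save_states and load_states are illustrative names, and the key format is an assumption rather than the format Lightning actually generates. The point is only that differently configured instances of the same callback class each find their own entry under checkpoint["callbacks"].

```python
from typing import Any, Dict, List, Optional


class SketchCheckpoint:
    """Hypothetical stand-in for a stateful callback (not the Lightning class)."""

    def __init__(self, monitor: str, mode: str = "min") -> None:
        self.monitor = monitor
        self.mode = mode
        self.best_score: Optional[float] = None

    @property
    def state_key(self) -> str:
        # Assumption: the key is built from the arguments that distinguish
        # instances, so two differently configured callbacks never collide.
        return f"{type(self).__name__}(monitor={self.monitor!r}, mode={self.mode!r})"


def save_states(callbacks: List[SketchCheckpoint]) -> Dict[str, Any]:
    """Store each callback's state under its own key."""
    return {"callbacks": {cb.state_key: {"best_score": cb.best_score} for cb in callbacks}}


def load_states(callbacks: List[SketchCheckpoint], checkpoint: Dict[str, Any]) -> None:
    """Restore state by key, independent of the order of the callbacks."""
    for cb in callbacks:
        state = checkpoint["callbacks"].get(cb.state_key)
        if state is not None:
            cb.best_score = state["best_score"]


loss_cb = SketchCheckpoint(monitor="val_loss", mode="min")
acc_cb = SketchCheckpoint(monitor="val_acc", mode="max")
loss_cb.best_score, acc_cb.best_score = 0.21, 0.93
checkpoint = save_states([loss_cb, acc_cb])

# Fresh instances, listed in a different order, still pick up the right state.
restored_acc = SketchCheckpoint(monitor="val_acc", mode="max")
restored_loss = SketchCheckpoint(monitor="val_loss", mode="min")
load_states([restored_acc, restored_loss], checkpoint)
print(restored_loss.best_score, restored_acc.best_score)
```

If both instances shared one key, the second entry would overwrite the first, which is why a callback that keeps state and may be instantiated several times needs distinct keys.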