ModelCheckpoint

class lightning.pytorch.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, enable_version_counter=True)

Bases: Checkpoint

Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() is a candidate for the monitor key. For more information, see Checkpointing.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters:
  • dirpath (Union[str, Path, None]) –

    directory to save the model file.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    

    By default, dirpath is None and is set at runtime to the location given by the Trainer’s default_root_dir argument. If the Trainer uses a logger, the path also includes the logger’s name and version.

  • filename (Optional[str]) –

    checkpoint filename. Can contain named formatting options to be auto-filled.

    Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filename is None and will be set to '{epoch}-{step}', where “epoch” and “step” match the number of finished epochs and optimizer steps, respectively.

  • monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Union[bool, Literal['link'], None]) – When True, saves a last.ckpt copy whenever a checkpoint file gets saved. Can be set to 'link' on a local filesystem to create a symbolic link. This allows accessing the latest checkpoint in a deterministic manner. Default: None.

  • save_top_k (int) – if save_top_k == k, the best k models according to the monitored quantity are saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch without the filename changing, a version count starting with v1 is appended to the name of the saved file to avoid collisions, unless enable_version_counter is set to False. The version counter is unrelated to the top-k ranking of the checkpoint, so we recommend including the monitored metric in the filename to avoid collisions.

  • mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should be 'min', etc.

  • auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02.0f}' with epoch 1 and acc 1.12 will resolve to checkpoint_epoch=01-acc=01.ckpt. It is useful to set this to False when metric names contain /, as such names would otherwise result in extra folders. For example: filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False.

  • save_weights_only (bool) – if True, then only the model’s weights will be saved. Otherwise, the optimizer states, lr-scheduler states, etc. are also included in the checkpoint.

  • every_n_train_steps (Optional[int]) – Number of training steps between checkpoints. If every_n_train_steps is None or 0, step-based checkpoints are skipped during training; to disable them explicitly, set every_n_train_steps = 0. This value must be None or non-negative, and it is mutually exclusive with train_time_interval and every_n_epochs.

  • train_time_interval (Optional[timedelta]) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with every_n_train_steps and every_n_epochs.

  • every_n_epochs (Optional[int]) – Number of epochs between checkpoints. This value must be None or non-negative. To disable saving top-k checkpoints, set every_n_epochs = 0. This argument does not impact the saving of save_last=True checkpoints. If all of every_n_epochs, every_n_train_steps and train_time_interval are None, we save a checkpoint at the end of every epoch (equivalent to every_n_epochs = 1). If every_n_epochs == None and either every_n_train_steps != None or train_time_interval != None, saving at the end of each epoch is disabled (equivalent to every_n_epochs = 0). This must be mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both values for every_n_epochs and check_val_every_n_epoch evenly divide E.

  • save_on_train_epoch_end (Optional[bool]) – Whether to run checkpointing at the end of the training epoch. If this is False, the check runs at the end of validation instead.

  • enable_version_counter (bool) – Whether to append a version to the existing file name. If this is False, then the checkpoint files will be overwritten.
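The interaction between every_n_epochs and check_val_every_n_epoch described above can be checked with a few lines of plain Python. The helper checkpoint_epochs below is hypothetical (it is not part of Lightning) and simply applies the stated rule that, with save_on_train_epoch_end=False, a checkpoint is only written at epochs E with 0 < E <= N that both intervals divide evenly.

```python
from typing import List


def checkpoint_epochs(max_epochs: int, every_n_epochs: int, check_val_every_n_epoch: int) -> List[int]:
    """Hypothetical helper: 1-based epochs E (0 < E <= max_epochs) divisible by both intervals."""
    if every_n_epochs == 0:
        # every_n_epochs = 0 disables saving top-k checkpoints entirely.
        return []
    return [
        epoch
        for epoch in range(1, max_epochs + 1)
        if epoch % every_n_epochs == 0 and epoch % check_val_every_n_epoch == 0
    ]


# Validation every 2 epochs, checkpointing requested every 3 epochs:
# only epochs divisible by both intervals produce a checkpoint.
print(checkpoint_epochs(max_epochs=20, every_n_epochs=3, check_val_every_n_epoch=2))
```

In effect, the save schedule becomes every lcm(V, M) epochs, so choosing V as a multiple of M avoids surprisingly sparse checkpoints.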

Note

For extra customization, ModelCheckpoint includes the following attributes:

  • CHECKPOINT_JOIN_CHAR = "-"

  • CHECKPOINT_EQUALS_CHAR = "="

  • CHECKPOINT_NAME_LAST = "last"

  • FILE_EXTENSION = ".ckpt"

  • STARTING_VERSION = 1

For example, you can change the default last checkpoint name by doing checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
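As a rough illustration of how these attributes fit together, the sketch below builds file names from CHECKPOINT_JOIN_CHAR, STARTING_VERSION, CHECKPOINT_NAME_LAST, and FILE_EXTENSION. The helpers versioned_filename and next_free_name are hypothetical, and the assumption that the version suffix is written as the join character followed by v<number> (for example -v1) is inferred from the save_top_k description, not taken from Lightning's source.

```python
from typing import Optional, Set

# Defaults documented for ModelCheckpoint's class attributes.
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1


def versioned_filename(stem: str, version: Optional[int] = None) -> str:
    """Append '<join char>v<version>' when a version is given, then the extension."""
    if version is not None:
        stem = f"{stem}{CHECKPOINT_JOIN_CHAR}v{version}"
    return stem + FILE_EXTENSION


def next_free_name(stem: str, existing: Set[str]) -> str:
    """Return the first non-colliding name, counting up from STARTING_VERSION (assumed scheme)."""
    name = versioned_filename(stem)
    version = STARTING_VERSION
    while name in existing:
        name = versioned_filename(stem, version)
        version += 1
    return name


existing = {"epoch=0-step=10.ckpt"}
print(next_free_name("epoch=0-step=10", existing))
print(versioned_filename(CHECKPOINT_NAME_LAST))
```

With enable_version_counter=False, the collision loop is effectively skipped and the existing file is overwritten instead.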

If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.
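Each save schedule is evaluated independently, which is why combining them takes separate callbacks, for example ModelCheckpoint(train_time_interval=timedelta(hours=N)), ModelCheckpoint(every_n_train_steps=M), and ModelCheckpoint(every_n_epochs=K). The sketch below models the three schedules with hypothetical trigger classes that are not Lightning APIs; to keep it deterministic, the time trigger takes the current time in seconds as an argument instead of reading a wall clock.

```python
from datetime import timedelta


class StepTrigger:
    """Hypothetical: fires when global_step is a multiple of every_n_train_steps."""

    def __init__(self, every_n_train_steps: int):
        self.every_n_train_steps = every_n_train_steps

    def should_save(self, global_step: int) -> bool:
        # A value of 0 disables step-based saving.
        return self.every_n_train_steps >= 1 and global_step % self.every_n_train_steps == 0


class TimeTrigger:
    """Hypothetical: fires once at least `interval` has elapsed since the previous save."""

    def __init__(self, interval: timedelta, start_seconds: float = 0.0):
        self.interval = interval.total_seconds()
        self.last_save = start_seconds

    def should_save(self, now_seconds: float) -> bool:
        if now_seconds - self.last_save >= self.interval:
            self.last_save = now_seconds  # restart the interval from this save
            return True
        return False


class EpochTrigger:
    """Hypothetical: fires after every n-th finished epoch (epochs counted from 1)."""

    def __init__(self, every_n_epochs: int):
        self.every_n_epochs = every_n_epochs

    def should_save(self, finished_epochs: int) -> bool:
        return self.every_n_epochs > 0 and finished_epochs % self.every_n_epochs == 0


# Three independent schedules, one object each (one callback each in Lightning terms).
steps = StepTrigger(every_n_train_steps=100)
hours = TimeTrigger(timedelta(hours=2))
epochs = EpochTrigger(every_n_epochs=5)
print(steps.should_save(300), hours.should_save(3 * 3600), epochs.should_save(4))
```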

If the checkpoint’s dirpath has changed since the checkpoint was saved and training is resumed from it, only best_model_path is reloaded and a warning is issued.

Raises:
  • MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is not one of -1, 0, or 1, or if mode is not one of "min" or "max".

  • ValueError – If trainer.save_checkpoint is None.
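The checks listed under Raises can be sketched as a plain validation function. The name validate_checkpoint_config is hypothetical, and it raises ValueError where Lightning raises MisconfigurationException. Two further assumptions: the set of save_top_k values allowed without a monitor is taken to be {-1, 0, 1} (so that the default ModelCheckpoint(), with monitor=None and save_top_k=1, passes), and the non-negativity and mutual-exclusivity rules from the parameter descriptions are enforced here as well, with a value of 0 treated as disabled.

```python
from datetime import timedelta
from typing import Optional


def validate_checkpoint_config(
    monitor: Optional[str] = None,
    save_top_k: int = 1,
    mode: str = "min",
    every_n_train_steps: Optional[int] = None,
    train_time_interval: Optional[timedelta] = None,
    every_n_epochs: Optional[int] = None,
) -> None:
    """Hypothetical init-time validation; raises ValueError on a bad combination."""
    if save_top_k < -1:
        raise ValueError(f"save_top_k={save_top_k} must be >= -1")
    # Without a monitored metric there is nothing to rank (assumed allowed set).
    if monitor is None and save_top_k not in (-1, 0, 1):
        raise ValueError(f"save_top_k={save_top_k} requires a monitor")
    if mode not in ("min", "max"):
        raise ValueError(f"mode={mode!r} must be 'min' or 'max'")
    for name, value in (("every_n_train_steps", every_n_train_steps), ("every_n_epochs", every_n_epochs)):
        if value is not None and value < 0:
            raise ValueError(f"{name}={value} must be None or non-negative")
    # Count the schedules that are actually enabled; 0 means disabled.
    enabled = [
        every_n_train_steps is not None and every_n_train_steps >= 1,
        every_n_epochs is not None and every_n_epochs >= 1,
        train_time_interval is not None,
    ]
    if sum(enabled) > 1:
        raise ValueError("every_n_train_steps, train_time_interval and every_n_epochs are mutually exclusive")


validate_checkpoint_config()  # the defaults are accepted
validate_checkpoint_config(monitor="val_loss", save_top_k=3, mode="min")
```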

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with TensorBoard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...  # replace with your LightningModule instance
trainer.fit(model)
print(checkpoint_callback.best_model_path)

Tip

Saving and restoring multiple checkpoint callbacks at the same time is supported when the callbacks differ in at least one of the following arguments:

monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval

Read more: Persisting Callback State

file_exists(filepath, trainer)

Checks whether a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type:

bool

format_checkpoint_name(metrics, filename=None, ver=None)

Generate a filename according to the defined template.

Example:

>>> import os
>>> import tempfile
>>> from lightning.pytorch.callbacks import ModelCheckpoint
>>> tmpdir = tempfile.mkdtemp()
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
'epoch=2.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'

Return type:

str
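To make the templating rules concrete, here is a small, self-contained re-implementation using a regular expression. fill_template is hypothetical and is not Lightning's code: it omits dirpath handling and version suffixes, and treats missing metrics as 0 to match the '{missing:d}' doctest above. With auto_insert_metric_name=True, each placeholder {name} or {name:spec} becomes name=<formatted value>; with False, only the formatted value is inserted.

```python
import re
from typing import Dict

# A placeholder is "{name}" or "{name:spec}"; the name may contain "/" (e.g. "val/loss").
_PLACEHOLDER = re.compile(r"\{([^{}:]+)(?::([^{}]*))?\}")


def fill_template(template: str, metrics: Dict[str, float], auto_insert_metric_name: bool = True) -> str:
    """Hypothetical filename templating, for illustration only."""

    def substitute(match):
        name, spec = match.group(1), match.group(2) or ""
        value = metrics.get(name, 0)  # missing metrics are filled with 0
        text = format(value, spec)
        return f"{name}={text}" if auto_insert_metric_name else text

    return _PLACEHOLDER.sub(substitute, template) + ".ckpt"


print(fill_template("{epoch}-{val_loss:.2f}", {"epoch": 2, "val_loss": 0.123456}))
print(fill_template("epoch{epoch:02d}-val_loss{val/loss:.2f}", {"epoch": 2, "val/loss": 0.32},
                    auto_insert_metric_name=False))
```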

load_state_dict(state_dict)

Called when loading a checkpoint; reloads the callback’s state from the given state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Save a checkpoint at the end of a training batch if the every_n_train_steps criterion is met.

Return type:

None

on_train_epoch_end(trainer, pl_module)

Save a checkpoint at the end of the training epoch.

Return type:

None

on_train_start(trainer, pl_module)

Called when training begins.

Return type:

None

on_validation_end(trainer, pl_module)

Save a checkpoint at the end of the validation stage.

Return type:

None

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()

Called when saving a checkpoint; generates the callback’s state_dict.

Return type:

Dict[str, Any]

Returns:

A dictionary containing callback state.
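The round trip between state_dict and load_state_dict, together with the rule from the Note that only best_model_path is reloaded (with a warning) when dirpath has changed, can be sketched with a hypothetical TinyCheckpointState class. The set of stored fields is an assumption chosen for illustration; the real callback's state dictionary may contain different keys.

```python
import warnings
from typing import Any, Dict, Optional


class TinyCheckpointState:
    """Hypothetical stand-in for the state a checkpoint callback persists."""

    def __init__(self, dirpath: str):
        self.dirpath = dirpath
        self.best_model_path = ""
        self.best_model_score: Optional[float] = None
        self.best_k_models: Dict[str, float] = {}

    def state_dict(self) -> Dict[str, Any]:
        # Called when saving: snapshot everything needed to resume.
        return {
            "dirpath": self.dirpath,
            "best_model_path": self.best_model_path,
            "best_model_score": self.best_model_score,
            "best_k_models": dict(self.best_k_models),
        }

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when loading: restore the full state only if dirpath is unchanged.
        if state_dict["dirpath"] == self.dirpath:
            self.best_model_score = state_dict["best_model_score"]
            self.best_k_models = dict(state_dict["best_k_models"])
        else:
            warnings.warn("dirpath changed since the checkpoint was saved; only best_model_path is restored")
        self.best_model_path = state_dict["best_model_path"]


saved = TinyCheckpointState("runs/a")
saved.best_model_path = "runs/a/epoch=3.ckpt"
saved.best_model_score = 0.25
saved.best_k_models = {"runs/a/epoch=3.ckpt": 0.25}

resumed = TinyCheckpointState("runs/b")  # different dirpath, so a warning is emitted
resumed.load_state_dict(saved.state_dict())
print(resumed.best_model_path, resumed.best_model_score)
```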

to_yaml(filepath=None)

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.

Return type:

None
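A minimal sketch of the top-k bookkeeping behind best_k_models, best_model_path, save_top_k, and mode. TopKTracker is hypothetical: it keeps a dict from checkpoint path to score, evicts the worst entry once more than save_top_k are stored, and returns the evicted (or rejected) path so the caller could delete that file. Writing the dict to YAML, as to_yaml does, is left out to keep the example dependency-free.

```python
from typing import Dict, Optional


class TopKTracker:
    """Hypothetical top-k bookkeeping: checkpoint path -> monitored score."""

    def __init__(self, save_top_k: int = 1, mode: str = "min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.save_top_k = save_top_k
        self.mode = mode
        self.best_k_models: Dict[str, float] = {}

    def update(self, path: str, score: float) -> Optional[str]:
        """Offer a new checkpoint; return the path that was evicted or rejected, if any."""
        if self.save_top_k == 0:
            return path  # save_top_k == 0: nothing is kept
        self.best_k_models[path] = score
        if self.save_top_k == -1 or len(self.best_k_models) <= self.save_top_k:
            return None  # -1 keeps everything; otherwise there is still room
        worst = self._pick(worst=True)
        del self.best_k_models[worst]
        return worst

    def _pick(self, worst: bool) -> str:
        # In "min" mode the best score is the smallest; in "max" mode, the largest.
        choose_max = (self.mode == "min") == worst
        pick = max if choose_max else min
        return pick(self.best_k_models, key=self.best_k_models.__getitem__)

    @property
    def best_model_path(self) -> str:
        return self._pick(worst=False) if self.best_k_models else ""


tracker = TopKTracker(save_top_k=2, mode="min")
for epoch, val_loss in enumerate([0.50, 0.40, 0.45, 0.30]):
    evicted = tracker.update(f"epoch={epoch}.ckpt", val_loss)
    if evicted is not None:
        print("would delete", evicted)
print(tracker.best_k_models, tracker.best_model_path)
```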

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
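One way to derive a unique state key is to combine the class name with the arguments that distinguish instances. The helper below is hypothetical, and the assumption that the key is built from monitor, mode, every_n_train_steps, every_n_epochs, and train_time_interval comes from the Tip on saving multiple checkpoint callbacks, not from Lightning's source.

```python
from datetime import timedelta
from typing import Any


def generate_state_key(class_name: str, **kwargs: Any) -> str:
    """Hypothetical helper: class name followed by the repr of the distinguishing arguments."""
    return f"{class_name}{kwargs!r}"


# Two checkpoint callbacks that differ in monitor/mode/schedule get different keys,
# so their states can live side by side under checkpoint["callbacks"].
key_a = generate_state_key(
    "ModelCheckpoint", monitor="val_loss", mode="min",
    every_n_train_steps=0, every_n_epochs=1, train_time_interval=None,
)
key_b = generate_state_key(
    "ModelCheckpoint", monitor="val_acc", mode="max",
    every_n_train_steps=0, every_n_epochs=None, train_time_interval=timedelta(hours=1),
)
checkpoint = {"callbacks": {key_a: {"best_model_path": "a.ckpt"}, key_b: {"best_model_path": "b.ckpt"}}}
print(key_a)
```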