ModelCheckpoint

class pytorch_lightning.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, every_n_val_epochs=None)[source]

Bases: pytorch_lightning.callbacks.base.Callback

Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() in LightningModule is a candidate for the monitor key. For more information, see Saving and loading weights.

After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.

Parameters
  • dirpath (Union[str, Path, None]) –

    directory to save the model file.

    Example:

    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    

    By default, dirpath is None and is set at runtime to the location specified by the Trainer’s default_root_dir or weights_save_path argument. If the Trainer uses a logger, the path also contains the logger name and version.

  • filename (Optional[str]) –

    checkpoint filename. Can contain named formatting options to be auto-filled.

    Example:

    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
    

    By default, filename is None and will be set to '{epoch}-{step}'.

  • monitor (Optional[str]) – quantity to monitor. By default it is None which saves a checkpoint only for the last epoch.

  • verbose (bool) – verbosity mode. Default: False.

  • save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.

  • save_top_k (int) – If save_top_k == k, the best k models according to the monitored quantity are saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Note that the monitored quantity is checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file is appended with a version count starting with v1.

  • mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc', this should be 'max', for 'val_loss' this should be 'min', etc.

  • auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02d}' with epoch 1 and acc 80 will resolve to checkpoint_epoch=01-acc=80.ckpt. It is useful to set this to False when metric names contain /, as the slash would otherwise result in extra folders.

  • save_weights_only (bool) – if True, only the model’s weights are saved; otherwise the checkpoint also includes the remaining training state, such as optimizer and learning-rate scheduler states.

  • every_n_train_steps (Optional[int]) – Number of training steps between checkpoints. If every_n_train_steps == None or every_n_train_steps == 0, we skip saving during training. To disable, set every_n_train_steps = 0. This value must be None or non-negative. This must be mutually exclusive with train_time_interval and every_n_epochs.

  • train_time_interval (Optional[timedelta]) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with every_n_train_steps and every_n_epochs.

  • every_n_epochs (Optional[int]) – Number of epochs between checkpoints. If every_n_epochs == None or every_n_epochs == 0, we skip saving when the epoch ends. To disable, set every_n_epochs = 0. This value must be None or non-negative. This must be mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both values for every_n_epochs and check_val_every_n_epoch evenly divide E.

  • save_on_train_epoch_end (Optional[bool]) – Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of the validation.

  • every_n_val_epochs (Optional[int]) –

    Number of epochs between checkpoints.

    Warning

    This argument has been deprecated in v1.4 and will be removed in v1.6.

    Use every_n_epochs instead.
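The interaction between every_n_epochs and check_val_every_n_epoch described above can be illustrated with a minimal sketch. The helper checkpoint_epochs is hypothetical and not part of PyTorch Lightning; it only enumerates the epochs E with 0 < E <= N that both values evenly divide, assuming save_on_train_epoch_end=False.

```python
def checkpoint_epochs(max_epochs, every_n_epochs, check_val_every_n_epoch):
    """Hypothetical helper: list the epochs at which a checkpoint would be saved.

    Mirrors the documented rule only: an epoch E (counted from 1) qualifies
    when both every_n_epochs and check_val_every_n_epoch evenly divide E.
    """
    if not every_n_epochs or every_n_epochs < 1:
        # None or 0 means saving at epoch end is skipped.
        return []
    return [
        epoch
        for epoch in range(1, max_epochs + 1)
        if epoch % every_n_epochs == 0 and epoch % check_val_every_n_epoch == 0
    ]


# every_n_epochs=2 and check_val_every_n_epoch=3 over 12 epochs
print(checkpoint_epochs(12, 2, 3))  # [6, 12]
```

With these values, validation runs at epochs 3, 6, 9 and 12, but only 6 and 12 are also multiples of every_n_epochs.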

Note

For extra customization, ModelCheckpoint includes the following attributes:

  • CHECKPOINT_JOIN_CHAR = "-"

  • CHECKPOINT_NAME_LAST = "last"

  • FILE_EXTENSION = ".ckpt"

  • STARTING_VERSION = 1

For example, you can change the default last checkpoint name by setting checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last".

If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.

Raises
  • MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is none of "min" or "max".

  • ValueError – If trainer.save_checkpoint is None.

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path

Tip

Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments:

monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end

Read more: Persisting State

file_exists(filepath, trainer)[source]

Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.

Return type

bool

format_checkpoint_name(metrics, filename=None, ver=None)[source]

Generate a filename according to the defined template.

Example:

>>> import os
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
'epoch=2.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
... filename='epoch={epoch}-validation_loss={val_loss:.2f}',
... auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'

Return type

str
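To make the templating behaviour in the doctest above easier to follow, here is a simplified, standalone sketch. The function sketch_format_checkpoint_name is hypothetical and is not the library's implementation; it assumes each metric name appears once in the template, fills missing metrics with 0, and appends the default ".ckpt" extension.

```python
import re


def sketch_format_checkpoint_name(template, metrics,
                                  auto_insert_metric_name=True,
                                  extension=".ckpt"):
    """Hypothetical, simplified re-creation of the documented naming rules."""
    metrics = dict(metrics)  # copy so the caller's dict is not modified
    # Capture each "{name" prefix, stopping at ":" (format spec) or "}".
    for group in re.findall(r"(\{.*?)[:\}]", template):
        name = group[1:]
        if auto_insert_metric_name:
            # "{epoch" becomes "epoch={epoch", giving "epoch=<value>".
            template = template.replace(group, name + "={" + name)
        # Metrics missing from the dict default to 0.
        metrics.setdefault(name, 0)
    return template.format(**metrics) + extension


print(sketch_format_checkpoint_name("{epoch}-{val_loss:.2f}",
                                    {"epoch": 2, "val_loss": 0.123456}))
# epoch=2-val_loss=0.12.ckpt
```

The sketch ignores directory handling, version suffixes, and metric names containing special characters, which the real method has to deal with.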

on_init_end(trainer)[source]

Called when the trainer initialization ends, model has not yet been set.

Return type

None

on_load_checkpoint(trainer, pl_module, callback_state)[source]

Called when loading a model checkpoint; use this to reload state.

Return type

None

Note

The on_load_checkpoint won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.

on_pretrain_routine_start(trainer, pl_module)[source]

When the pretrain routine starts, the checkpoint directory is built on the fly.

Return type

None

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a model checkpoint; use this to persist state.

Return type

Dict[str, Any]

Returns

The callback state.

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)[source]

Saves a checkpoint at the end of a training batch if the criteria for every_n_train_steps are met.

Return type

None
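The step-based criterion can be pictured with a small sketch. The helper should_save_on_step is hypothetical; it assumes the callback is given the number of training steps completed so far, and it treats None and 0 as "disabled", as described for every_n_train_steps.

```python
def should_save_on_step(completed_steps, every_n_train_steps):
    """Hypothetical check: save after every N completed training steps."""
    if every_n_train_steps is None or every_n_train_steps < 1:
        # None or 0 disables step-based checkpointing.
        return False
    return completed_steps % every_n_train_steps == 0


# Steps (counted from 1) at which a checkpoint would be written with N=5
print([s for s in range(1, 11) if should_save_on_step(s, 5)])  # [5, 10]
```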

on_train_end(trainer, pl_module)[source]

Save a checkpoint when training stops.

This only saves a checkpoint if save_last is also enabled, because the monitor metrics logged during training/validation steps or at the end of epochs are not guaranteed to be available at this stage.

Return type

None

on_train_epoch_end(trainer, pl_module)[source]

Save a checkpoint at the end of the training epoch.

Return type

None

on_train_start(trainer, pl_module)[source]

Called when training begins.

Return type

None

on_validation_end(trainer, pl_module)[source]

Save a checkpoint at the end of the validation stage.

Return type

None

save_checkpoint(trainer)[source]

Performs the main logic around saving a checkpoint.

This method runs on all ranks. It is the responsibility of trainer.save_checkpoint to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.

Return type

None
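One part of that logic, the top-k bookkeeping governed by save_top_k and mode, can be sketched in plain Python. The function update_top_k and its return convention are assumptions made for illustration; the actual callback also handles file removal, versioning, and distributed synchronization.

```python
def update_top_k(best_k, path, score, save_top_k, mode="min"):
    """Hypothetical top-k update.

    Returns (new_best_k, evicted_path); evicted_path is None when nothing
    has to be removed.
    """
    if save_top_k == 0:
        return dict(best_k), None  # no models are saved
    new_best_k = dict(best_k)
    new_best_k[path] = score
    if save_top_k == -1 or len(new_best_k) <= save_top_k:
        return new_best_k, None  # keep everything seen so far
    # For mode="min" the worst entry has the highest score, and vice versa.
    pick_worst = max if mode == "min" else min
    worst_path = pick_worst(new_best_k, key=new_best_k.get)
    del new_best_k[worst_path]
    return new_best_k, worst_path


best = {}
for ckpt_path, val_loss in [("a.ckpt", 0.5), ("b.ckpt", 0.3), ("c.ckpt", 0.4)]:
    best, evicted = update_top_k(best, ckpt_path, val_loss, save_top_k=2)

print(sorted(best))            # ['b.ckpt', 'c.ckpt']
print(min(best, key=best.get)) # b.ckpt
```

In this sketch the entry with the lowest score plays the role of best_model_path when mode is 'min'.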

to_yaml(filepath=None)[source]

Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.

Return type

None

property state_key: str

Identifier for the state of the callback.

Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
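The role of a unique state key can be shown with a plain dictionary standing in for the checkpoint. The helper make_state_key and the key format it produces are assumptions for illustration only; the point is that two instances of the same callback class need distinct keys so their states do not overwrite each other.

```python
def make_state_key(class_name, **init_args):
    """Hypothetical key: class name plus the arguments that distinguish instances."""
    return f"{class_name}{init_args!r}"


# A plain dict standing in for the checkpoint dictionary.
checkpoint = {"callbacks": {}}

key_loss = make_state_key("ModelCheckpoint", monitor="val_loss", mode="min")
key_acc = make_state_key("ModelCheckpoint", monitor="val_acc", mode="max")

# Each instance stores its state under its own key.
checkpoint["callbacks"][key_loss] = {"best_model_score": 0.12}
checkpoint["callbacks"][key_acc] = {"best_model_score": 0.97}

# On reload, each instance retrieves only its own state.
restored_loss_state = checkpoint["callbacks"][key_loss]
restored_acc_state = checkpoint["callbacks"][key_acc]
```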
