ModelCheckpoint
- class pytorch_lightning.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None)
Bases: pytorch_lightning.callbacks.checkpoint.Checkpoint
Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() in LightningModule is a candidate for the monitor key. For more information, see Checkpointing.
After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.
- Parameters:
dirpath (Union[str, Path, None]) – directory to save the model file.
Example:
# custom path
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is None and will be set at runtime to the location specified by the Trainer's default_root_dir or weights_save_path arguments, and if the Trainer uses a logger, the path will also contain the logger name and version.
filename (Optional[str]) – checkpoint filename. Can contain named formatting options to be auto-filled.
Example:
# save any arbitrary metrics like `val_loss`, etc. in name
# saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     dirpath='my/path',
...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
... )
By default, filename is None and will be set to '{epoch}-{step}'.
monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.
save_last (Optional[bool]) – When True, saves an exact copy of the checkpoint to a file last.ckpt whenever a checkpoint file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: None.
save_top_k (int) – If save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v1.
mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc' this should be 'max', for 'val_loss' this should be 'min', etc.
auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02.0f}' with epoch 1 and acc 1.12 will resolve to checkpoint_epoch=01-acc=01.ckpt. It is useful to set it to False when metric names contain /, as these would otherwise result in extra folders. For example, filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False.
save_weights_only (bool) – If True, then only the model's weights will be saved. Otherwise, the optimizer states, lr-scheduler states, etc. are added to the checkpoint too.
every_n_train_steps (Optional[int]) – Number of training steps between checkpoints. If every_n_train_steps == None or every_n_train_steps == 0, we skip saving during training. To disable, set every_n_train_steps = 0. This value must be None or non-negative. This must be mutually exclusive with train_time_interval and every_n_epochs.
train_time_interval (Optional[timedelta]) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with every_n_train_steps and every_n_epochs.
every_n_epochs (Optional[int]) – Number of epochs between checkpoints. This value must be None or non-negative. To disable saving top-k checkpoints, set every_n_epochs = 0. This argument does not impact the saving of save_last=True checkpoints. If all of every_n_epochs, every_n_train_steps and train_time_interval are None, we save a checkpoint at the end of every epoch (equivalent to every_n_epochs = 1). If every_n_epochs == None and either every_n_train_steps != None or train_time_interval != None, saving at the end of each epoch is disabled (equivalent to every_n_epochs = 0). This must be mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both every_n_epochs and check_val_every_n_epoch evenly divide E.
save_on_train_epoch_end (Optional[bool]) – Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of validation.
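The epoch rule at the end of the every_n_epochs description reduces to a divisibility check. The helper below, saved_epochs, is a hypothetical stdlib-only sketch (not a Lightning API) that lists the qualifying 1-based epochs for N (max_epochs), V (every_n_epochs) and M (check_val_every_n_epoch), assuming save_on_train_epoch_end=False so that saving is tied to validation.

```python
from typing import List


def saved_epochs(max_epochs: int, every_n_epochs: int, check_val_every_n_epoch: int) -> List[int]:
    """Hypothetical helper: 1-based epochs E with 0 < E <= max_epochs at which
    both every_n_epochs and check_val_every_n_epoch evenly divide E."""
    if every_n_epochs <= 0:
        # every_n_epochs = 0 disables epoch-based top-k saving.
        return []
    return [
        epoch
        for epoch in range(1, max_epochs + 1)
        if epoch % every_n_epochs == 0 and epoch % check_val_every_n_epoch == 0
    ]


# V = 2, M = 3, N = 12: only epochs divisible by both 2 and 3 qualify.
print(saved_epochs(max_epochs=12, every_n_epochs=2, check_val_every_n_epoch=3))  # [6, 12]
```

Because E must be a multiple of both V and M, the result is just the multiples of lcm(V, M) up to N.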
Note
For extra customization, ModelCheckpoint includes the following attributes:
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1
For example, you can change the default last checkpoint name by doing
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.
If the checkpoint's dirpath changed from what it was before while resuming the training, only best_model_path will be reloaded and a warning will be issued.
- Raises:
MisconfigurationException – If save_top_k is smaller than -1, if monitor is None and save_top_k is none of -1, 0, and 1, or if mode is none of "min" or "max".
ValueError – If trainer.save_checkpoint is None.
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import ModelCheckpoint

# saves checkpoints to 'my/path/' at every epoch
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
>>> trainer = Trainer(callbacks=[checkpoint_callback])

# save epoch and val_loss in name
# saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val_loss',
...     dirpath='my/path/',
...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
... )

# save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
# or Neptune, due to the presence of characters like '=' or '/')
# saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
>>> checkpoint_callback = ModelCheckpoint(
...     monitor='val/loss',
...     dirpath='my/path/',
...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
...     auto_insert_metric_name=False
... )

# retrieve the best checkpoint after training
checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
trainer = Trainer(callbacks=[checkpoint_callback])
model = ...
trainer.fit(model)
checkpoint_callback.best_model_path
Tip
Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the following arguments:
monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end
Read more: Persisting Callback State
- file_exists(filepath, trainer)
Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.
- Return type: bool
- format_checkpoint_name(metrics, filename=None, ver=None)
Generate a filename according to the defined template.
Example:
>>> import os
>>> from pytorch_lightning.callbacks import ModelCheckpoint
>>> tmpdir = os.path.dirname(__file__)
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
'epoch=0.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
'epoch=005.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-val_loss=0.12.ckpt'
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.12), filename='{epoch:d}'))
'epoch=2.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir,
...     filename='epoch={epoch}-validation_loss={val_loss:.2f}',
...     auto_insert_metric_name=False)
>>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
'epoch=2-validation_loss=0.12.ckpt'
>>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
>>> os.path.basename(ckpt.format_checkpoint_name({}))
'missing=0.ckpt'
>>> ckpt = ModelCheckpoint(filename='{step}')
>>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
'step=0.ckpt'
- Return type: str
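The templating behind format_checkpoint_name is essentially Python format specs plus a name= prefix, so it can be approximated without Lightning. The stdlib-only function format_name below is a hypothetical sketch, not the real method: it assumes placeholders of the form {name} or {name:spec}, substitutes 0 for missing metrics, prefixes each value with name= when auto_insert_metric_name is True, and appends a version tag joined with CHECKPOINT_JOIN_CHAR ("-") before FILE_EXTENSION (".ckpt"). It ignores dirpath and takes plain numbers where Lightning passes tensors.

```python
import re
from typing import Mapping, Optional

# "{name}" or "{name:spec}"; names may contain "/" (e.g. "val/acc").
_PLACEHOLDER = re.compile(r"\{([^:{}]+)(?::([^{}]*))?\}")


def format_name(template: str, metrics: Mapping[str, float],
                auto_insert_metric_name: bool = True, ver: Optional[int] = None) -> str:
    """Hypothetical approximation of format_checkpoint_name (not Lightning's code)."""

    def substitute(match: re.Match) -> str:
        name, spec = match.group(1), match.group(2) or ""
        # Metrics missing from `metrics` fall back to 0, as in the 'missing=0.ckpt' doctest.
        text = format(metrics.get(name, 0), spec)
        return f"{name}={text}" if auto_insert_metric_name else text

    stem = _PLACEHOLDER.sub(substitute, template)
    if ver is not None:
        # Assumed suffix: CHECKPOINT_JOIN_CHAR ("-") followed by "v<ver>".
        stem = "-".join((stem, f"v{ver}"))
    return stem + ".ckpt"  # FILE_EXTENSION


print(format_name("{epoch}-{val_loss:.2f}", {"epoch": 2, "val_loss": 0.123456}))
# epoch=2-val_loss=0.12.ckpt
print(format_name("checkpoint_{epoch:02d}-{acc:02.0f}", {"epoch": 1, "acc": 1.12}))
# checkpoint_epoch=01-acc=01.ckpt
print(format_name("epoch={epoch}-val_acc={val/acc:.2f}", {"epoch": 3, "val/acc": 0.91},
                  auto_insert_metric_name=False))
# epoch=3-val_acc=0.91.ckpt
print(format_name("{epoch}", {"epoch": 0}, ver=1))
# epoch=0-v1.ckpt
```

Using a single regular-expression pass keeps metric names such as val/acc out of str.format's field-name parsing, which would otherwise reject them.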
- load_state_dict(state_dict)
Called when loading a checkpoint, implement to reload callback state given callback's state_dict.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Save a checkpoint on train batch end if the criteria for every_n_train_steps are met.
- Return type: None
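The batch-end decision can be pictured as two independent triggers: one counting optimizer steps (every_n_train_steps) and one measuring elapsed time (train_time_interval). The TrainBatchTrigger class below is a hypothetical, single-process simplification, not the callback's code: the current step and a monotonic time are passed in explicitly so the logic is easy to test, the first call merely starts the clock, and any cross-rank synchronization is omitted.

```python
from datetime import timedelta
from typing import Optional


class TrainBatchTrigger:
    """Hypothetical single-process sketch of the batch-end checkpoint decision."""

    def __init__(self, every_n_train_steps: Optional[int] = None,
                 train_time_interval: Optional[timedelta] = None) -> None:
        # None and 0 both mean "no step-based saving during training".
        self.every_n_train_steps = every_n_train_steps or 0
        if self.every_n_train_steps < 0:
            raise ValueError("every_n_train_steps must be None or non-negative")
        if self.every_n_train_steps and train_time_interval is not None:
            raise ValueError("every_n_train_steps and train_time_interval are mutually exclusive")
        self.train_time_interval = train_time_interval
        self.last_time_checked: Optional[float] = None

    def should_save(self, global_step: int, now: float) -> bool:
        """global_step: optimizer steps completed so far; now: monotonic time in seconds."""
        step_due = self.every_n_train_steps >= 1 and global_step % self.every_n_train_steps == 0
        time_due = False
        if self.train_time_interval is not None:
            if self.last_time_checked is None:
                # Simplification: the first call only starts the clock.
                self.last_time_checked = now
            elif now - self.last_time_checked >= self.train_time_interval.total_seconds():
                time_due = True
                self.last_time_checked = now
        return step_due or time_due


by_steps = TrainBatchTrigger(every_n_train_steps=5)
print([step for step in range(1, 11) if by_steps.should_save(step, now=0.0)])  # [5, 10]

by_time = TrainBatchTrigger(train_time_interval=timedelta(minutes=30))
print([by_time.should_save(step, now=t) for step, t in [(1, 0.0), (2, 900.0), (3, 1800.0), (4, 2000.0)]])
# [False, False, True, False]
```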
- on_train_epoch_end(trainer, pl_module)
Save a checkpoint at the end of the training epoch.
- Return type: None
- on_validation_end(trainer, pl_module)
Save a checkpoint at the end of the validation stage.
- Return type: None
- save_checkpoint(trainer)
Performs the main logic around saving a checkpoint.
This method runs on all ranks. It is the responsibility of trainer.save_checkpoint to correctly handle the behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
- Return type: None
- setup(trainer, pl_module, stage=None)
Called when fit, validate, test, predict, or tune begins.
- Return type: None
- to_yaml(filepath=None)
Saves the best_k_models dict containing the checkpoint paths with the corresponding scores to a YAML file.
- Return type: None
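The best_k_models mapping that to_yaml writes out can be illustrated with a small stdlib-only tracker. TopKTracker below is a hypothetical sketch rather than ModelCheckpoint's implementation: it assumes a new checkpoint is kept only if it beats the worst currently kept score under mode, handles the documented save_top_k values 0, -1 and k, and reports which path would be evicted instead of deleting files. Tie-breaking and file I/O are left out.

```python
from typing import Dict, Optional, Tuple


class TopKTracker:
    """Hypothetical sketch of save_top_k / mode bookkeeping (not Lightning's code)."""

    def __init__(self, save_top_k: int = 1, mode: str = "min") -> None:
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.save_top_k = save_top_k
        self.mode = mode
        # Maps checkpoint path -> monitored score, like best_k_models.
        self.best_k_models: Dict[str, float] = {}

    def _is_better(self, a: float, b: float) -> bool:
        return a < b if self.mode == "min" else a > b

    def _worst_path(self) -> str:
        # For "min" the worst kept score is the largest; for "max" it is the smallest.
        pick = max if self.mode == "min" else min
        return pick(self.best_k_models, key=self.best_k_models.__getitem__)

    def update(self, path: str, score: float) -> Tuple[bool, Optional[str]]:
        """Offer a new checkpoint; return (kept, path_to_evict_or_None)."""
        if self.save_top_k == 0:
            return False, None  # save_top_k == 0: nothing is saved
        if self.save_top_k == -1 or len(self.best_k_models) < self.save_top_k:
            self.best_k_models[path] = score  # -1 keeps everything; otherwise fill up to k
            return True, None
        worst = self._worst_path()
        if not self._is_better(score, self.best_k_models[worst]):
            return False, None
        del self.best_k_models[worst]
        self.best_k_models[path] = score
        return True, worst

    @property
    def best_model_path(self) -> Optional[str]:
        if not self.best_k_models:
            return None
        pick = min if self.mode == "min" else max
        return pick(self.best_k_models, key=self.best_k_models.__getitem__)


tracker = TopKTracker(save_top_k=2, mode="min")
for epoch, val_loss in enumerate([0.9, 0.5, 0.7, 0.3]):
    tracker.update(f"epoch={epoch}.ckpt", val_loss)
print(tracker.best_k_models)    # {'epoch=1.ckpt': 0.5, 'epoch=3.ckpt': 0.3}
print(tracker.best_model_path)  # epoch=3.ckpt
```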
- property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
- Return type: str
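One way such a unique key might be derived: the sketch below assumes, in line with the Tip on saving multiple checkpoint callbacks, that the key only needs to change when monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval or save_on_train_epoch_end changes, and builds it from the class name plus the repr of those values. The function checkpoint_state_key is hypothetical; Lightning's actual key format may differ, so code should not parse it.

```python
from datetime import timedelta
from typing import Optional


def checkpoint_state_key(monitor: Optional[str], mode: str = "min",
                         every_n_train_steps: int = 0, every_n_epochs: int = 1,
                         train_time_interval: Optional[timedelta] = None,
                         save_on_train_epoch_end: Optional[bool] = None) -> str:
    """Hypothetical helper: callbacks that differ in any of these arguments
    get distinct keys; identical configurations get identical keys."""
    kwargs = dict(
        monitor=monitor,
        mode=mode,
        every_n_train_steps=every_n_train_steps,
        every_n_epochs=every_n_epochs,
        train_time_interval=train_time_interval,
        save_on_train_epoch_end=save_on_train_epoch_end,
    )
    # Class name followed by the repr of the distinguishing arguments.
    return f"ModelCheckpoint{kwargs!r}"


# Two callbacks tracking different metrics store their state side by side.
checkpoint = {"callbacks": {}}
loss_key = checkpoint_state_key(monitor="val_loss", mode="min")
acc_key = checkpoint_state_key(monitor="val_acc", mode="max")
checkpoint["callbacks"][loss_key] = {"best_model_score": 0.3}
checkpoint["callbacks"][acc_key] = {"best_model_score": 0.91}
print(len(checkpoint["callbacks"]))  # 2
```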