ModelCheckpoint
- class pytorch_lightning.callbacks.ModelCheckpoint(dirpath=None, filename=None, monitor=None, verbose=False, save_last=None, save_top_k=1, save_weights_only=False, mode='min', auto_insert_metric_name=True, every_n_train_steps=None, train_time_interval=None, every_n_epochs=None, save_on_train_epoch_end=None, period=None, every_n_val_epochs=None)
Bases: pytorch_lightning.callbacks.base.Callback
Save the model periodically by monitoring a quantity. Every metric logged with log() or log_dict() in LightningModule is a candidate for the monitor key. For more information, see Saving and loading weights.
After training finishes, use best_model_path to retrieve the path to the best checkpoint file and best_model_score to retrieve its score.
- Parameters
dirpath (Union[str, Path, None]) – directory to save the model file.
Example:
    # custom path
    # saves a file like: my/path/epoch=0-step=10.ckpt
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is None and will be set at runtime to the location specified by Trainer’s default_root_dir or weights_save_path arguments, and if the Trainer uses a logger, the path will also contain logger name and version.
filename (Optional[str]) – checkpoint filename. Can contain named formatting options to be auto-filled.
Example:
    # save any arbitrary metrics like `val_loss`, etc. in name
    # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     dirpath='my/path',
    ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
    ... )
By default, filename is None and will be set to '{epoch}-{step}'.
monitor (Optional[str]) – quantity to monitor. By default it is None, which saves a checkpoint only for the last epoch.
save_last (Optional[bool]) – When True, always saves the model at the end of the epoch to a file last.ckpt. Default: None.
save_top_k (int) – If save_top_k == k, the best k models according to the quantity monitored will be saved. If save_top_k == 0, no models are saved. If save_top_k == -1, all models are saved. Please note that the monitors are checked every every_n_epochs epochs. If save_top_k >= 2 and the callback is called multiple times inside an epoch, the name of the saved file will be appended with a version count starting with v1.
mode (str) – one of {min, max}. If save_top_k != 0, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For 'val_acc', this should be 'max'; for 'val_loss', this should be 'min'; etc.
auto_insert_metric_name (bool) – When True, the checkpoint filenames will contain the metric name. For example, filename='checkpoint_{epoch:02d}-{acc:02d}' with epoch 1 and acc 80 will resolve to checkpoint_epoch=01-acc=80.ckpt. It is useful to set it to False when metric names contain /, because the / would otherwise result in extra folders.
save_weights_only (bool) – If True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath)).
every_n_train_steps (Optional[int]) – Number of training steps between checkpoints. If every_n_train_steps == None or every_n_train_steps == 0, we skip saving during training. To disable, set every_n_train_steps = 0. This value must be None or non-negative. This must be mutually exclusive with train_time_interval and every_n_epochs.
train_time_interval (Optional[timedelta]) – Checkpoints are monitored at the specified time interval. For all practical purposes, this cannot be smaller than the amount of time it takes to process a single training batch. This is not guaranteed to execute at the exact time specified, but should be close. This must be mutually exclusive with every_n_train_steps and every_n_epochs.
every_n_epochs (Optional[int]) – Number of epochs between checkpoints. If every_n_epochs == None or every_n_epochs == 0, we skip saving when the epoch ends. To disable, set every_n_epochs = 0. This value must be None or non-negative. This must be mutually exclusive with every_n_train_steps and train_time_interval. Setting both ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False) and Trainer(max_epochs=N, check_val_every_n_epoch=M) will only save checkpoints at epochs 0 < E <= N where both values for every_n_epochs and check_val_every_n_epoch evenly divide E.
save_on_train_epoch_end (Optional[bool]) – Whether to run checkpointing at the end of the training epoch. If this is False, then the check runs at the end of the validation.
period (Optional[int]) – Interval (number of epochs) between checkpoints.
Warning
This argument has been deprecated in v1.3 and will be removed in v1.5. Use every_n_epochs instead.
every_n_val_epochs (Optional[int]) – Number of epochs between checkpoints.
Warning
This argument has been deprecated in v1.4 and will be removed in v1.6. Use every_n_epochs instead.
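The rule stated for every_n_epochs (checkpoints are saved only at epochs E where both every_n_epochs and check_val_every_n_epoch evenly divide E) can be worked through with plain arithmetic. The following is a minimal sketch and not Lightning code: checkpoint_epochs is a hypothetical helper, and it assumes save_on_train_epoch_end=False and the 1-based epoch count 0 < E <= N used in the description above.

```python
from typing import List


def checkpoint_epochs(max_epochs: int, every_n_epochs: int, check_val_every_n_epoch: int) -> List[int]:
    """Hypothetical helper: list the epochs E (0 < E <= max_epochs) that satisfy the documented rule."""
    return [
        epoch
        for epoch in range(1, max_epochs + 1)
        # Both intervals must evenly divide E for a checkpoint to be saved.
        if epoch % every_n_epochs == 0 and epoch % check_val_every_n_epoch == 0
    ]


# With N=12, V=2 and M=3, only multiples of both 2 and 3 qualify.
print(checkpoint_epochs(max_epochs=12, every_n_epochs=2, check_val_every_n_epoch=3))
```

Under these assumptions, choosing intervals that share few common multiples can leave long stretches of training without a checkpoint.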
Note
For extra customization, ModelCheckpoint includes the following attributes:
CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"
STARTING_VERSION = 1
For example, you can change the default last checkpoint name by doing
checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"
If you want to checkpoint every N hours, every M train batches, and/or every K val epochs, then you should create multiple ModelCheckpoint callbacks.
- Raises
MisconfigurationException – If save_top_k is not None and is smaller than -1, if monitor is None and save_top_k is none of None, -1, and 0, or if mode is neither "min" nor "max".
ValueError – If trainer.save_checkpoint is None.
Example:
    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import ModelCheckpoint

    # saves checkpoints to 'my/path/' at every epoch
    >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    >>> trainer = Trainer(callbacks=[checkpoint_callback])

    # save epoch and val_loss in name
    # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val_loss',
    ...     dirpath='my/path/',
    ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
    ... )

    # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
    # or Neptune, due to the presence of characters like '=' or '/')
    # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
    >>> checkpoint_callback = ModelCheckpoint(
    ...     monitor='val/loss',
    ...     dirpath='my/path/',
    ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
    ...     auto_insert_metric_name=False
    ... )

    # retrieve the best checkpoint after training
    checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
    trainer = Trainer(callbacks=[checkpoint_callback])
    model = ...
    trainer.fit(model)
    checkpoint_callback.best_model_path
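The semantics of save_top_k and mode described in the parameter list can be pictured with a small bookkeeping routine. This is a simplified sketch in plain Python, not the callback's actual implementation: track_top_k and the dictionary of kept checkpoints are hypothetical names, and file handling is left out entirely.

```python
from typing import Dict, Optional


def track_top_k(kept: Dict[str, float], path: str, score: float, k: int, mode: str = "min") -> Optional[str]:
    """Hypothetical helper: record a candidate and return the path that is dropped, if any."""
    if k == 0:
        # save_top_k == 0: no models are saved, so the candidate itself is dropped.
        return path
    kept[path] = score
    if k == -1 or len(kept) <= k:
        # save_top_k == -1 keeps everything; otherwise there is still room.
        return None
    # In 'min' mode the worst entry has the highest score; in 'max' mode, the lowest.
    pick_worst = max if mode == "min" else min
    worst = pick_worst(kept, key=kept.get)
    del kept[worst]
    return worst


kept: Dict[str, float] = {}
for name, val_loss in [("a.ckpt", 0.5), ("b.ckpt", 0.3), ("c.ckpt", 0.4)]:
    dropped = track_top_k(kept, name, val_loss, k=2, mode="min")
    print(name, "dropped:", dropped)
```

With mode='min' and k=2, the entry with the highest val_loss is the one dropped once a third candidate arrives.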
- file_exists(filepath, trainer)
Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal state from diverging between ranks.
- Return type
- format_checkpoint_name(metrics, ver=None)
Generate a filename according to the defined template.
Example:
    >>> import os
    >>> tmpdir = os.path.dirname(__file__)
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}')
    >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=0)))
    'epoch=0.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch:03d}')
    >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=5)))
    'epoch=005.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{epoch}-{val_loss:.2f}')
    >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
    'epoch=2-val_loss=0.12.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir,
    ...     filename='epoch={epoch}-validation_loss={val_loss:.2f}',
    ...     auto_insert_metric_name=False)
    >>> os.path.basename(ckpt.format_checkpoint_name(dict(epoch=2, val_loss=0.123456)))
    'epoch=2-validation_loss=0.12.ckpt'
    >>> ckpt = ModelCheckpoint(dirpath=tmpdir, filename='{missing:d}')
    >>> os.path.basename(ckpt.format_checkpoint_name({}))
    'missing=0.ckpt'
    >>> ckpt = ModelCheckpoint(filename='{step}')
    >>> os.path.basename(ckpt.format_checkpoint_name(dict(step=0)))
    'step=0.ckpt'
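The template expansion shown in the example can be approximated in a few lines of standard-library Python. This is a sketch under stated assumptions rather than the library's code: format_name is a hypothetical function, it ignores dirpath and the ver argument, and it assumes that metrics missing from the dictionary are filled with 0, as the '{missing:d}' case above suggests.

```python
import re
from typing import Any, Dict

# Matches the opening of a placeholder such as {epoch} or {val_loss:.2f} and captures its name.
PLACEHOLDER = re.compile(r"\{([^{}:]+)")


def format_name(template: str, metrics: Dict[str, Any],
                auto_insert_metric_name: bool = True, extension: str = ".ckpt") -> str:
    """Hypothetical re-creation of the filename expansion illustrated above."""
    values = dict(metrics)
    for name in PLACEHOLDER.findall(template):
        # Assumption: a metric that was never logged is rendered as 0.
        values.setdefault(name, 0)
    if auto_insert_metric_name:
        # Turn '{epoch:03d}' into 'epoch={epoch:03d}' so the metric name appears in the file name.
        template = PLACEHOLDER.sub(r"\1={\1", template)
    return template.format(**values) + extension


print(format_name("{epoch}-{val_loss:.2f}", {"epoch": 2, "val_loss": 0.123456}))
```

Passing auto_insert_metric_name=False leaves the template untouched before formatting, which is why the example above spells out 'epoch=' and 'validation_loss=' by hand.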
- Return type
- on_load_checkpoint(trainer, pl_module, callback_state)
Called when loading a model checkpoint; use it to reload state.
- Parameters
pl_module (LightningModule) – the current LightningModule instance.
callback_state (Dict[str, Any]) – the callback state returned by on_save_checkpoint.
- Return type
Note
The on_load_checkpoint hook won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
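The pairing between on_save_checkpoint and on_load_checkpoint described in the note can be illustrated without the library. The sketch below is hypothetical throughout: BestScoreTracker only mirrors the hook names and argument order listed on this page, and restore is a stand-in for the trainer that assumes the load hook is skipped whenever no state was saved.

```python
from typing import Any, Dict, Optional


class BestScoreTracker:
    """Hypothetical callback-like class; it does not inherit from the real Callback base."""

    def __init__(self) -> None:
        self.best_score: Optional[float] = None

    def on_save_checkpoint(self, trainer, pl_module, checkpoint) -> Dict[str, Any]:
        # Returning a state here is what allows on_load_checkpoint to run later.
        return {"best_score": self.best_score}

    def on_load_checkpoint(self, trainer, pl_module, callback_state: Dict[str, Any]) -> None:
        self.best_score = callback_state["best_score"]


def restore(callback: BestScoreTracker, saved_state: Optional[Dict[str, Any]]) -> bool:
    """Stand-in for the trainer (assumption): call the load hook only if a state exists."""
    if saved_state:
        callback.on_load_checkpoint(None, None, saved_state)
        return True
    return False


source = BestScoreTracker()
source.best_score = 0.25
state = source.on_save_checkpoint(None, None, {})

target = BestScoreTracker()
print(restore(target, state), target.best_score)
print(restore(BestScoreTracker(), None))
```

A callback whose load hook needs no data could, under the same assumption, return any small placeholder dictionary from on_save_checkpoint so that the load hook still runs.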
- on_pretrain_routine_start(trainer, pl_module)
When the pretrain routine starts, the checkpoint directory is built on the fly.
- Return type
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a model checkpoint; use it to persist state.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)
Save a checkpoint at the end of a training batch if the criteria for every_n_train_steps are met.
- Return type
- on_train_end(trainer, pl_module)
Save a checkpoint when training stops.
This will only save a checkpoint if save_last is also enabled, because the monitor metrics logged during training/validation steps or at the end of epochs are not guaranteed to be available at this stage.
- Return type
- on_train_epoch_end(trainer, pl_module, unused=None)
Save a checkpoint at the end of the training epoch.
- on_validation_end(trainer, pl_module)
Save a checkpoint at the end of the validation stage.
- Return type
- save_checkpoint(trainer, unused=None)
Performs the main logic around saving a checkpoint. This method runs on all ranks. It is the responsibility of trainer.save_checkpoint to correctly handle the behaviour in distributed training, e.g., saving only on rank 0 for data parallel use cases.
- Return type