Source code for pytorch_lightning.callbacks.model_checkpoint
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Model Checkpointing
===================

Automatically save model checkpoints during training.

"""
import logging
import os
import re
import time
import warnings
from copy import deepcopy
from datetime import timedelta
from typing import Any, Dict, Optional
from weakref import proxy

import numpy as np
import torch
import yaml
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.logger import _name, _version
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _PATH, STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache

log = logging.getLogger(__name__)
warning_cache = WarningCache()
class ModelCheckpoint(Checkpoint):
    r"""
    Save the model periodically by monitoring a quantity. Every metric logged with
    :meth:`~pytorch_lightning.core.module.log` or :meth:`~pytorch_lightning.core.module.log_dict` in
    LightningModule is a candidate for the monitor key. For more information, see
    :ref:`checkpointing`.

    After training finishes, use :attr:`best_model_path` to retrieve the path to the
    best checkpoint file and :attr:`best_model_score` to retrieve its score.

    Args:
        dirpath: directory to save the model file.

            Example::

                # custom path
                # saves a file like: my/path/epoch=0-step=10.ckpt
                >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')

            By default, dirpath is ``None`` and will be set at runtime to the location
            specified by :class:`~pytorch_lightning.trainer.trainer.Trainer`'s
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.default_root_dir` or
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.weights_save_path` arguments,
            and if the Trainer uses a logger, the path will also contain logger name and version.

        filename: checkpoint filename. Can contain named formatting options to be auto-filled.

            Example::

                # save any arbitrary metrics like `val_loss`, etc. in name
                # saves a file like: my/path/epoch=2-val_loss=0.02-other_metric=0.03.ckpt
                >>> checkpoint_callback = ModelCheckpoint(
                ...     dirpath='my/path',
                ...     filename='{epoch}-{val_loss:.2f}-{other_metric:.2f}'
                ... )

            By default, filename is ``None`` and will be set to ``'{epoch}-{step}'``.

        monitor: quantity to monitor. By default it is ``None`` which saves a checkpoint only for the last epoch.
        verbose: verbosity mode. Default: ``False``.
        save_last: When ``True``, saves an exact copy of the checkpoint to a file `last.ckpt` whenever a checkpoint
            file gets saved. This allows accessing the latest checkpoint in a deterministic manner. Default: ``None``.
        save_top_k: if ``save_top_k == k``,
            the best k models according to the quantity monitored will be saved.
            if ``save_top_k == 0``, no models are saved.
            if ``save_top_k == -1``, all models are saved.
            Please note that the monitors are checked every ``every_n_epochs`` epochs.
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with ``v1``.
        mode: one of {min, max}.
            If ``save_top_k != 0``, the decision to overwrite the current save file is made
            based on either the maximization or the minimization of the monitored quantity.
            For ``'val_acc'``, this should be ``'max'``, for ``'val_loss'`` this should be ``'min'``, etc.
        auto_insert_metric_name: When ``True``, the checkpoints filenames will contain the metric name.
            For example, ``filename='checkpoint_{epoch:02d}-{acc:02.0f}'`` with epoch ``1`` and acc ``1.12`` will
            resolve to ``checkpoint_epoch=01-acc=01.ckpt``. It is useful to set it to ``False`` when metric names
            contain ``/``, as this will result in extra folders.
            For example, ``filename='epoch={epoch}-step={step}-val_acc={val/acc:.2f}', auto_insert_metric_name=False``
        save_weights_only: if ``True``, then only the model's weights will be
            saved. Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
        every_n_train_steps: Number of training steps between checkpoints.
            If ``every_n_train_steps == None or every_n_train_steps == 0``, we skip saving during training.
            To disable, set ``every_n_train_steps = 0``. This value must be ``None`` or non-negative.
            This must be mutually exclusive with ``train_time_interval`` and ``every_n_epochs``.
        train_time_interval: Checkpoints are monitored at the specified time interval.
            For all practical purposes, this cannot be smaller than the amount
            of time it takes to process a single training batch. This is not
            guaranteed to execute at the exact time specified, but should be close.
            This must be mutually exclusive with ``every_n_train_steps`` and ``every_n_epochs``.
        every_n_epochs: Number of epochs between checkpoints.
            This value must be ``None`` or non-negative.
            To disable saving top-k checkpoints, set ``every_n_epochs = 0``.
            This argument does not impact the saving of ``save_last=True`` checkpoints.
            If all of ``every_n_epochs``, ``every_n_train_steps`` and
            ``train_time_interval`` are ``None``, we save a checkpoint at the end of every epoch
            (equivalent to ``every_n_epochs = 1``).
            If ``every_n_epochs == None`` and either ``every_n_train_steps != None`` or
            ``train_time_interval != None``, saving at the end of each epoch is disabled
            (equivalent to ``every_n_epochs = 0``).
            This must be mutually exclusive with ``every_n_train_steps`` and ``train_time_interval``.
            Setting both ``ModelCheckpoint(..., every_n_epochs=V, save_on_train_epoch_end=False)`` and
            ``Trainer(max_epochs=N, check_val_every_n_epoch=M)``
            will only save checkpoints at epochs 0 < E <= N
            where both values for ``every_n_epochs`` and ``check_val_every_n_epoch`` evenly divide E.
        save_on_train_epoch_end: Whether to run checkpointing at the end of the training epoch.
            If this is ``False``, then the check runs at the end of the validation.

    Note:
        For extra customization, ModelCheckpoint includes the following attributes:

        - ``CHECKPOINT_JOIN_CHAR = "-"``
        - ``CHECKPOINT_NAME_LAST = "last"``
        - ``FILE_EXTENSION = ".ckpt"``
        - ``STARTING_VERSION = 1``

        For example, you can change the default last checkpoint name by doing
        ``checkpoint_callback.CHECKPOINT_NAME_LAST = "{epoch}-last"``

        If you want to checkpoint every N hours, every M train batches, and/or every K val epochs,
        then you should create multiple ``ModelCheckpoint`` callbacks.

        If the checkpoint's ``dirpath`` changed from what it was before while resuming the training,
        only ``best_model_path`` will be reloaded and a warning will be issued.

    Raises:
        MisconfigurationException:
            If ``save_top_k`` is smaller than ``-1``,
            if ``monitor`` is ``None`` and ``save_top_k`` is none of ``None``, ``-1``, and ``0``, or
            if ``mode`` is none of ``"min"`` or ``"max"``.
        ValueError:
            If ``trainer.save_checkpoint`` is ``None``.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import ModelCheckpoint

        # saves checkpoints to 'my/path/' at every epoch
        >>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
        >>> trainer = Trainer(callbacks=[checkpoint_callback])

        # save epoch and val_loss in name
        # saves a file like: my/path/sample-mnist-epoch=02-val_loss=0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     monitor='val_loss',
        ...     dirpath='my/path/',
        ...     filename='sample-mnist-{epoch:02d}-{val_loss:.2f}'
        ... )

        # save epoch and val_loss in name, but specify the formatting yourself (e.g. to avoid problems with Tensorboard
        # or Neptune, due to the presence of characters like '=' or '/')
        # saves a file like: my/path/sample-mnist-epoch02-val_loss0.32.ckpt
        >>> checkpoint_callback = ModelCheckpoint(
        ...     monitor='val/loss',
        ...     dirpath='my/path/',
        ...     filename='sample-mnist-epoch{epoch:02d}-val_loss{val/loss:.2f}',
        ...     auto_insert_metric_name=False
        ... )

        # retrieve the best checkpoint after training
        checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
        trainer = Trainer(callbacks=[checkpoint_callback])
        model = ...
        trainer.fit(model)
        checkpoint_callback.best_model_path

    .. tip:: Saving and restoring multiple checkpoint callbacks at the same time is supported under variation in the
        following arguments:

        *monitor, mode, every_n_train_steps, every_n_epochs, train_time_interval, save_on_train_epoch_end*

        Read more: :ref:`Persisting Callback State <extensions/callbacks_state:save callback state>`
    """

    CHECKPOINT_JOIN_CHAR = "-"
    CHECKPOINT_NAME_LAST = "last"
    FILE_EXTENSION = ".ckpt"
    STARTING_VERSION = 1

    def __init__(
        self,
        dirpath: Optional[_PATH] = None,
        filename: Optional[str] = None,
        monitor: Optional[str] = None,
        verbose: bool = False,
        save_last: Optional[bool] = None,
        save_top_k: int = 1,
        save_weights_only: bool = False,
        mode: str = "min",
        auto_insert_metric_name: bool = True,
        every_n_train_steps: Optional[int] = None,
        train_time_interval: Optional[timedelta] = None,
        every_n_epochs: Optional[int] = None,
        save_on_train_epoch_end: Optional[bool] = None,
    ):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.save_last = save_last
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.auto_insert_metric_name = auto_insert_metric_name
        self._save_on_train_epoch_end = save_on_train_epoch_end
        self._last_global_step_saved = 0  # no need to save when no steps were taken
        self._last_time_checked: Optional[float] = None
        self.current_score: Optional[Tensor] = None
        self.best_k_models: Dict[str, Tensor] = {}
        self.kth_best_model_path = ""
        self.best_model_score: Optional[Tensor] = None
        self.best_model_path = ""
        self.last_model_path = ""
        self.kth_value: Tensor

        self.__init_monitor_mode(mode)
        self.__init_ckpt_dir(dirpath, filename)
        self.__init_triggers(every_n_train_steps, every_n_epochs, train_time_interval)
        self.__validate_init_configuration()

    @property
    def state_key(self) -> str:
        return self._generate_state_key(
            monitor=self.monitor,
            mode=self.mode,
            every_n_train_steps=self._every_n_train_steps,
            every_n_epochs=self._every_n_epochs,
            train_time_interval=self._train_time_interval,
            save_on_train_epoch_end=self._save_on_train_epoch_end,
        )
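    # Illustrative sketch (hypothetical values, not part of the module source): the three save
    # triggers configured in ``__init__`` are mutually exclusive, so a callback is typically created
    # with at most one of them:
    #
    #     >>> from datetime import timedelta
    #     >>> ModelCheckpoint(every_n_train_steps=100)                    # step-based trigger
    #     >>> ModelCheckpoint(every_n_epochs=2)                           # epoch-based trigger
    #     >>> ModelCheckpoint(train_time_interval=timedelta(minutes=30))  # time-based trigger
    #
    # Passing more than one active trigger raises a MisconfigurationException in
    # __validate_init_configuration; passing none falls back to every_n_epochs=1.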
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
        self.__resolve_ckpt_dir(trainer)
        assert self.dirpath is not None
        if trainer.is_global_zero and stage == "fit":
            self.__warn_if_dir_not_empty(self.dirpath)

        # NOTE: setting these attributes needs to happen as early as possible BEFORE reloading callback states,
        # because the attributes are part of the state_key which needs to be fully defined before reloading.
        if self._save_on_train_epoch_end is None:
            # if the user runs validation multiple times per training epoch or multiple training epochs without
            # validation, then we run after validation instead of on train epoch end
            self._save_on_train_epoch_end = trainer.val_check_interval == 1.0 and trainer.check_val_every_n_epoch == 1
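    # Illustrative note on the resolution above (hypothetical Trainer settings): with the defaults
    # (val_check_interval=1.0, check_val_every_n_epoch=1) this resolves _save_on_train_epoch_end to
    # True, so checkpointing runs in on_train_epoch_end; a setting such as
    # Trainer(val_check_interval=0.25) resolves it to False, so checkpointing runs in
    # on_validation_end instead.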
    def on_train_batch_end(
        self,
        trainer: "pl.Trainer",
        pl_module: "pl.LightningModule",
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
    ) -> None:
        """Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
        if self._should_skip_saving_checkpoint(trainer):
            return
        skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)

        train_time_interval = self._train_time_interval
        skip_time = True
        now = time.monotonic()
        if train_time_interval:
            prev_time_check = self._last_time_checked
            skip_time = prev_time_check is None or (now - prev_time_check) < train_time_interval.total_seconds()
            # in case we have time differences across ranks
            # broadcast the decision on whether to checkpoint from rank 0 to avoid possible hangs
            skip_time = trainer.strategy.broadcast(skip_time)

        if skip_batch and skip_time:
            return
        if not skip_time:
            self._last_time_checked = now

        monitor_candidates = self._monitor_candidates(trainer)
        self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)
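    # Illustrative walkthrough of the trigger logic above (hypothetical settings): with
    # every_n_train_steps=100 and no train_time_interval, skip_time stays True and a checkpoint is
    # written only when trainer.global_step is a multiple of 100. With only a train_time_interval,
    # the step trigger defaults to 0, so skip_batch is True and a checkpoint is written on the first
    # batch end at which at least that much time has elapsed since the last recorded time check.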
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Save a checkpoint at the end of the training epoch."""
        if not self._should_skip_saving_checkpoint(trainer) and self._save_on_train_epoch_end:
            monitor_candidates = self._monitor_candidates(trainer)
            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                self._save_topk_checkpoint(trainer, monitor_candidates)
            self._save_last_checkpoint(trainer, monitor_candidates)
    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Save a checkpoint at the end of the validation stage."""
        if not self._should_skip_saving_checkpoint(trainer) and not self._save_on_train_epoch_end:
            monitor_candidates = self._monitor_candidates(trainer)
            if self._every_n_epochs >= 1 and (trainer.current_epoch + 1) % self._every_n_epochs == 0:
                self._save_topk_checkpoint(trainer, monitor_candidates)
            self._save_last_checkpoint(trainer, monitor_candidates)
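    # Illustrative note: both epoch-end hooks above gate the top-k save on
    # (trainer.current_epoch + 1) % self._every_n_epochs == 0. For example, with every_n_epochs=2
    # the top-k save is considered after the 2nd, 4th, 6th, ... epoch, while the save_last
    # checkpoint (if enabled) is refreshed every time the hook runs.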
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        dirpath_from_ckpt = state_dict.get("dirpath", self.dirpath)

        if self.dirpath == dirpath_from_ckpt:
            self.best_model_score = state_dict["best_model_score"]
            self.kth_best_model_path = state_dict.get("kth_best_model_path", self.kth_best_model_path)
            self.kth_value = state_dict.get("kth_value", self.kth_value)
            self.best_k_models = state_dict.get("best_k_models", self.best_k_models)
            self.last_model_path = state_dict.get("last_model_path", self.last_model_path)
        else:
            warnings.warn(
                f"The dirpath has changed from {dirpath_from_ckpt!r} to {self.dirpath!r},"
                " therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and"
                " `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded."
            )

        self.best_model_path = state_dict["best_model_path"]
    def save_checkpoint(self, trainer: "pl.Trainer") -> None:  # pragma: no-cover
        """Performs the main logic around saving a checkpoint.

        This method runs on all ranks. It is the responsibility of `trainer.save_checkpoint` to correctly handle the
        behaviour in distributed training, i.e., saving only on rank 0 for data parallel use cases.
        """
        rank_zero_deprecation(
            f"`{self.__class__.__name__}.save_checkpoint()` was deprecated in v1.6 and will be removed in v1.8."
            " Instead, you can use `trainer.save_checkpoint()` to manually save a checkpoint."
        )
        monitor_candidates = self._monitor_candidates(trainer)
        self._save_topk_checkpoint(trainer, monitor_candidates)
        self._save_last_checkpoint(trainer, monitor_candidates)
    def _save_topk_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
        if self.save_top_k == 0:
            return

        # validate metric
        if self.monitor is not None:
            if self.monitor not in monitor_candidates:
                m = (
                    f"`ModelCheckpoint(monitor={self.monitor!r})` could not find the monitored key in the returned"
                    f" metrics: {list(monitor_candidates)}."
                    f" HINT: Did you call `log({self.monitor!r}, value)` in the `LightningModule`?"
                )
                if trainer.fit_loop.epoch_loop.val_loop._has_run:
                    raise MisconfigurationException(m)
                warning_cache.warn(m)
            self._save_monitor_checkpoint(trainer, monitor_candidates)
        else:
            self._save_none_monitor_checkpoint(trainer, monitor_candidates)

    def _save_checkpoint(self, trainer: "pl.Trainer", filepath: str) -> None:
        trainer.save_checkpoint(filepath, self.save_weights_only)

        self._last_global_step_saved = trainer.global_step

        # notify loggers
        if trainer.is_global_zero:
            for logger in trainer.loggers:
                logger.after_save_checkpoint(proxy(self))

    def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
        from pytorch_lightning.trainer.states import TrainerFn

        return (
            bool(trainer.fast_dev_run)  # disable checkpointing with fast_dev_run
            or trainer.state.fn != TrainerFn.FITTING  # don't save anything during non-fit
            or trainer.sanity_checking  # don't save anything during sanity check
            or self._last_global_step_saved == trainer.global_step  # already saved at the last step
        )

    def __validate_init_configuration(self) -> None:
        if self.save_top_k < -1:
            raise MisconfigurationException(f"Invalid value for save_top_k={self.save_top_k}. Must be >= -1")
        if self._every_n_train_steps < 0:
            raise MisconfigurationException(
                f"Invalid value for every_n_train_steps={self._every_n_train_steps}. Must be >= 0"
            )
        if self._every_n_epochs < 0:
            raise MisconfigurationException(f"Invalid value for every_n_epochs={self._every_n_epochs}. Must be >= 0")

        every_n_train_steps_triggered = self._every_n_train_steps >= 1
        every_n_epochs_triggered = self._every_n_epochs >= 1
        train_time_interval_triggered = self._train_time_interval is not None
        if every_n_train_steps_triggered + every_n_epochs_triggered + train_time_interval_triggered > 1:
            raise MisconfigurationException(
                f"Combination of parameters every_n_train_steps={self._every_n_train_steps}, "
                f"every_n_epochs={self._every_n_epochs} and train_time_interval={self._train_time_interval} "
                "should be mutually exclusive."
            )

        if self.monitor is None:
            # -1: save all epochs, 0: nothing is saved, 1: save last epoch
            if self.save_top_k not in (-1, 0, 1):
                raise MisconfigurationException(
                    f"ModelCheckpoint(save_top_k={self.save_top_k}, monitor=None) is not a valid"
                    " configuration. No quantity for top_k to track."
                )

            if self.save_top_k == -1 and self.save_last:
                rank_zero_info(
                    "ModelCheckpoint(save_last=True, save_top_k=-1, monitor=None)"
                    " will duplicate the last checkpoint saved."
                )

    def __init_ckpt_dir(self, dirpath: Optional[_PATH], filename: Optional[str]) -> None:
        self._fs = get_filesystem(dirpath if dirpath else "")

        if dirpath and self._fs.protocol == "file":
            dirpath = os.path.realpath(dirpath)

        self.dirpath = dirpath
        self.filename = filename

    def __init_monitor_mode(self, mode: str) -> None:
        torch_inf = torch.tensor(np.Inf)
        mode_dict = {"min": (torch_inf, "min"), "max": (-torch_inf, "max")}

        if mode not in mode_dict:
            raise MisconfigurationException(f"`mode` can be {', '.join(mode_dict.keys())} but got {mode}")

        self.kth_value, self.mode = mode_dict[mode]

    def __init_triggers(
        self,
        every_n_train_steps: Optional[int],
        every_n_epochs: Optional[int],
        train_time_interval: Optional[timedelta],
    ) -> None:
        # Default to running once after each validation epoch if neither
        # every_n_train_steps nor every_n_epochs is set
        if every_n_train_steps is None and every_n_epochs is None and train_time_interval is None:
            every_n_epochs = 1
            every_n_train_steps = 0
            log.debug("Both every_n_train_steps and every_n_epochs are not set. Setting every_n_epochs=1")
        else:
            every_n_epochs = every_n_epochs or 0
            every_n_train_steps = every_n_train_steps or 0

        self._train_time_interval: Optional[timedelta] = train_time_interval
        self._every_n_epochs: int = every_n_epochs
        self._every_n_train_steps: int = every_n_train_steps

    @property
    def every_n_epochs(self) -> Optional[int]:
        return self._every_n_epochs

    def check_monitor_top_k(self, trainer: "pl.Trainer", current: Optional[Tensor] = None) -> bool:
        if current is None:
            return False

        if self.save_top_k == -1:
            return True

        less_than_k_models = len(self.best_k_models) < self.save_top_k
        if less_than_k_models:
            return True

        monitor_op = {"min": torch.lt, "max": torch.gt}[self.mode]
        should_update_best_and_save = monitor_op(current, self.best_k_models[self.kth_best_model_path])

        # If using multiple devices, make sure all processes are unanimous on the decision.
        should_update_best_and_save = trainer.strategy.reduce_boolean_decision(bool(should_update_best_and_save))

        return should_update_best_and_save

    @classmethod
    def _format_checkpoint_name(
        cls,
        filename: Optional[str],
        metrics: Dict[str, Tensor],
        prefix: str = "",
        auto_insert_metric_name: bool = True,
    ) -> str:
        if not filename:
            # filename is not set, use default name
            filename = "{epoch}" + cls.CHECKPOINT_JOIN_CHAR + "{step}"

        # check and parse user passed keys in the string
        groups = re.findall(r"(\{.*?)[:\}]", filename)
        if len(groups) >= 0:
            for group in groups:
                name = group[1:]

                if auto_insert_metric_name:
                    filename = filename.replace(group, name + "={" + name)

                # support for dots: https://stackoverflow.com/a/7934969
                filename = filename.replace(group, f"{{0[{name}]")

                if name not in metrics:
                    metrics[name] = torch.tensor(0)
            filename = filename.format(metrics)

        if prefix:
            filename = cls.CHECKPOINT_JOIN_CHAR.join([prefix, filename])

        return filename
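    # Illustrative sketch of the template resolution in _format_checkpoint_name above, assuming
    # tensor-valued metrics (the values are hypothetical):
    #
    #     >>> ModelCheckpoint._format_checkpoint_name(
    #     ...     "{epoch}-{val_loss:.2f}", {"epoch": torch.tensor(2), "val_loss": torch.tensor(0.1234)}
    #     ... )
    #     'epoch=2-val_loss=0.12'
    #
    # With auto_insert_metric_name=False the "epoch=" / "val_loss=" prefixes are not inserted and the
    # template is filled verbatim; keys missing from ``metrics`` are substituted with torch.tensor(0).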
    def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
        """Determines model checkpoint save directory at runtime.

        Reference attributes from the trainer's logger to determine where to save checkpoints. The path for saving
        weights is set in this priority:

        1.  The ``ModelCheckpoint``'s ``dirpath`` if passed in
        2.  The ``Trainer``'s ``weights_save_path`` if passed in (deprecated)
        3.  The ``Logger``'s ``log_dir`` if the trainer has loggers
        4.  The ``Trainer``'s ``default_root_dir`` if the trainer has no loggers

        The path gets extended with subdirectory "checkpoints".
        """
        if self.dirpath is not None:
            # short circuit if dirpath was passed to ModelCheckpoint
            return

        # TODO: Remove weights_save_path logic here in v1.8
        if trainer._weights_save_path_internal != trainer.default_root_dir:
            # the user has changed weights_save_path
            ckpt_path = os.path.join(trainer._weights_save_path_internal, "checkpoints")
        elif trainer.loggers:
            if len(trainer.loggers) == 1:
                assert trainer.logger is not None
                save_dir = trainer.logger.save_dir or trainer.default_root_dir
            else:
                save_dir = trainer.default_root_dir

            name = _name(trainer.loggers)
            version = _version(trainer.loggers)
            version = version if isinstance(version, str) else f"version_{version}"
            ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
        else:
            # if no loggers, use default_root_dir
            ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints")

        ckpt_path = trainer.strategy.broadcast(ckpt_path)

        self.dirpath = ckpt_path

    def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None:
        if self.save_top_k != 0 and self._fs.isdir(dirpath) and len(self._fs.ls(dirpath)) > 0:
            rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

    def _get_metric_interpolated_filepath_name(
        self, monitor_candidates: Dict[str, Tensor], trainer: "pl.Trainer", del_filepath: Optional[str] = None
    ) -> str:
        filepath = self.format_checkpoint_name(monitor_candidates)

        version_cnt = self.STARTING_VERSION
        while self.file_exists(filepath, trainer) and filepath != del_filepath:
            filepath = self.format_checkpoint_name(monitor_candidates, ver=version_cnt)
            version_cnt += 1

        return filepath

    def _monitor_candidates(self, trainer: "pl.Trainer") -> Dict[str, Tensor]:
        monitor_candidates = deepcopy(trainer.callback_metrics)
        # cast to int if necessary because `self.log("epoch", 123)` will convert it to float. if it's not a tensor
        # or does not exist we overwrite it as it's likely an error
        epoch = monitor_candidates.get("epoch")
        monitor_candidates["epoch"] = epoch.int() if isinstance(epoch, Tensor) else torch.tensor(trainer.current_epoch)
        step = monitor_candidates.get("step")
        monitor_candidates["step"] = step.int() if isinstance(step, Tensor) else torch.tensor(trainer.global_step)
        return monitor_candidates

    def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
        if not self.save_last:
            return

        filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)

        version_cnt = self.STARTING_VERSION
        while self.file_exists(filepath, trainer) and filepath != self.last_model_path:
            filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST, ver=version_cnt)
            version_cnt += 1

        # set the last model path before saving because it will be part of the state.
        previous, self.last_model_path = self.last_model_path, filepath
        self._save_checkpoint(trainer, filepath)
        if previous and previous != filepath:
            trainer.strategy.remove_checkpoint(previous)

    def _save_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
        assert self.monitor
        current = monitor_candidates.get(self.monitor)
        if self.check_monitor_top_k(trainer, current):
            assert current is not None
            self._update_best_and_save(current, trainer, monitor_candidates)
        elif self.verbose:
            epoch = monitor_candidates["epoch"]
            step = monitor_candidates["step"]
            rank_zero_info(f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} was not in top {self.save_top_k}")

    def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]) -> None:
        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
        # set the best model path before saving because it will be part of the state.
        previous, self.best_model_path = self.best_model_path, filepath
        self._save_checkpoint(trainer, filepath)
        if self.save_top_k == 1 and previous and previous != filepath:
            trainer.strategy.remove_checkpoint(previous)

    def _update_best_and_save(
        self, current: Tensor, trainer: "pl.Trainer", monitor_candidates: Dict[str, Tensor]
    ) -> None:
        k = len(self.best_k_models) + 1 if self.save_top_k == -1 else self.save_top_k

        del_filepath = None
        if len(self.best_k_models) == k and k > 0:
            del_filepath = self.kth_best_model_path
            self.best_k_models.pop(del_filepath)

        # do not save nan, replace with +/- inf
        if isinstance(current, Tensor) and torch.isnan(current):
            current = torch.tensor(float("inf" if self.mode == "min" else "-inf"), device=current.device)

        filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer, del_filepath)

        # save the current score
        self.current_score = current
        self.best_k_models[filepath] = current

        if len(self.best_k_models) == k:
            # monitor dict has reached k elements
            _op = max if self.mode == "min" else min
            self.kth_best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
            self.kth_value = self.best_k_models[self.kth_best_model_path]

        _op = min if self.mode == "min" else max
        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)  # type: ignore[arg-type]
        self.best_model_score = self.best_k_models[self.best_model_path]

        if self.verbose:
            epoch = monitor_candidates["epoch"]
            step = monitor_candidates["step"]
            rank_zero_info(
                f"Epoch {epoch:d}, global step {step:d}: {self.monitor!r} reached {current:0.5f}"
                f" (best {self.best_model_score:0.5f}), saving model to {filepath!r} as top {k}"
            )
        self._save_checkpoint(trainer, filepath)

        if del_filepath is not None and filepath != del_filepath:
            trainer.strategy.remove_checkpoint(del_filepath)
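    # Illustrative walkthrough of the top-k bookkeeping above (hypothetical values, mode="min",
    # save_top_k=2): after monitored values 0.5, 0.3 and 0.4, best_k_models keeps the two smallest
    # scores (0.3 and 0.4), best_model_path points at the 0.3 checkpoint, kth_best_model_path and
    # kth_value track the worst retained score (0.4), and the 0.5 checkpoint has been deleted via
    # trainer.strategy.remove_checkpoint.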
    def to_yaml(self, filepath: Optional[_PATH] = None) -> None:
        """Saves the `best_k_models` dict containing the checkpoint paths with the corresponding scores to a YAML
        file."""
        best_k = {k: v.item() for k, v in self.best_k_models.items()}
        if filepath is None:
            assert self.dirpath
            filepath = os.path.join(self.dirpath, "best_k_models.yaml")
        with self._fs.open(filepath, "w") as fp:
            yaml.dump(best_k, fp)
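    # Illustrative usage (hypothetical paths): after training, checkpoint_callback.to_yaml() writes a
    # mapping such as {'my/path/epoch=3-step=400.ckpt': 0.12, ...} to <dirpath>/best_k_models.yaml.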
    def file_exists(self, filepath: _PATH, trainer: "pl.Trainer") -> bool:
        """Checks if a file exists on rank 0 and broadcasts the result to all other ranks, preventing the internal
        state from diverging between ranks."""
        exists = self._fs.exists(filepath)
        return trainer.strategy.broadcast(exists)