Source code for pytorch_lightning.loops.epoch.training_epoch_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections import defaultdict
from typing import Any, Dict, Generator, List, Optional, overload, Tuple, Union

import numpy as np
import torch

import pytorch_lightning as pl
from pytorch_lightning import loops  # import as loops to avoid circular imports
from pytorch_lightning.loops.batch import TrainingBatchLoop
from pytorch_lightning.loops.batch.training_batch_loop import _OUTPUTS_TYPE as _BATCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _get_active_optimizers, _is_max_limit_reached, _v1_8_output_format
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
from pytorch_lightning.utilities.warnings import WarningCache

_OUTPUTS_TYPE = List[_BATCH_OUTPUTS_TYPE]

class TrainingEpochLoop(loops.Loop[_OUTPUTS_TYPE]):
    """Runs over all batches in a dataloader (one epoch).

    Args:
        min_steps: The minimum number of steps (batches) to process
        max_steps: The maximum number of steps (batches) to process
    """

    def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None:
        super().__init__()
        if max_steps is None:
            rank_zero_deprecation(
                "Setting `max_steps = None` is deprecated in v1.5 and will no longer be supported in v1.7."
                " Use `max_steps = -1` instead."
            )
            max_steps = -1
        elif max_steps < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {max_steps}."
            )
        self.min_steps = min_steps
        self.max_steps = max_steps

        self.batch_progress = BatchProgress()
        self.scheduler_progress = SchedulerProgress()

        self.batch_loop = TrainingBatchLoop()
        self.val_loop = loops.EvaluationLoop(verbose=False)

        self._results = _ResultCollection(training=True)
        self._outputs: _OUTPUTS_TYPE = []
        self._warning_cache = WarningCache()
        # caches the loaded dataloader state until dataloader objects are available
        self._dataloader_state_dict: Dict[str, Any] = {}
        self._batches_that_stepped: int = 0

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)"""
        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
        # but before the next `ready` increase
        return self.batch_progress.total.ready - 1

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)"""
        # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
        # but before the next `ready` increase
        return self.batch_progress.current.ready - 1

    @property
    def global_step(self) -> int:
        lightning_module = self.trainer.lightning_module
        if lightning_module is None or lightning_module.automatic_optimization:
            return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
        return self.batch_loop.manual_loop.optim_step_progress.total.completed

    @property
    def _is_training_done(self) -> bool:
        max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
        return max_steps_reached or self._num_ready_batches_reached()

    @property
    def _is_validation_done(self) -> bool:
        # when we are restarting we want to check whether the val loop has finished
        return not self.restarting or self.val_loop.done

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        return (self._is_training_done and self._is_validation_done) or self.trainer.should_stop
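
    # Illustrative annotation (not part of the original source): `max_steps=-1` means "no step
    # limit", while any other negative value is rejected by `__init__` above. A minimal sketch,
    # assuming the loop is constructed standalone before being attached to a trainer:
    #
    #     >>> TrainingEpochLoop(min_steps=None, max_steps=-1).max_steps
    #     -1
    #     >>> TrainingEpochLoop(max_steps=-2)  # doctest: +IGNORE_EXCEPTION_DETAIL
    #     Traceback (most recent call last):
    #     MisconfigurationException: `max_steps` must be a non-negative integer or -1 ...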

    def connect(  # type: ignore[override]
        self,
        batch_loop: Optional[TrainingBatchLoop] = None,
        val_loop: Optional["loops.EvaluationLoop"] = None,
    ) -> None:
        """Optionally connect a custom batch or validation loop to this training epoch loop."""
        if batch_loop is not None:
            self.batch_loop = batch_loop
        if val_loop is not None:
            self.val_loop = val_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        if self.restarting:
            self.batch_progress.reset_on_restart()
            self.scheduler_progress.reset_on_restart()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()

            trainer = self.trainer
            if not trainer.state._fault_tolerant_mode.is_enabled and trainer.num_training_batches != float("inf"):
                expected_steps = math.ceil(trainer.num_training_batches / trainer.accumulate_grad_batches)
                if self.global_step % expected_steps != 0:
                    rank_zero_warn(
                        "You're resuming from a checkpoint that ended before the epoch ended. This can cause unreliable"
                        " results if further training is done. Consider using an end-of-epoch checkpoint or enabling"
                        " fault-tolerant training:"
                        " https://pytorch-lightning.readthedocs.io/en/stable/advanced/fault_tolerant_training.html"
                    )
        else:
            self.batch_progress.reset_on_run()
            self.scheduler_progress.reset_on_run()
            self.batch_loop.optimizer_loop.optim_progress.reset_on_run()
            # when the epoch starts, the total val batch progress should be reset as it's supposed to count the batches
            # seen per epoch, this is useful for tracking when validation is run multiple times per epoch
            self.val_loop.epoch_loop.batch_progress.total.reset()

        self._outputs = []

    def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        self._reload_dataloader_state_dict(data_fetcher)
        _ = iter(data_fetcher)  # creates the iterator inside the fetcher
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready

    def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
        """Runs a single training batch.

        Raises:
            StopIteration: When the epoch is canceled by the user returning -1
        """
        if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
            # skip training and run validation in `on_advance_end`
            return
        # we are going to train first so the val loop does not need to restart
        self.val_loop.restarting = False

        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
            batch_idx = self.batch_idx + 1
            batch = next(data_fetcher)
        else:
            batch_idx, batch = next(data_fetcher)
        self.batch_progress.is_last_batch = data_fetcher.done

        self.batch_progress.increment_ready()

        self.trainer._logger_connector.on_batch_start(batch, batch_idx)

        if batch is None:
            self._warning_cache.warn("train_dataloader yielded None. If this was on purpose, ignore this warning...")
            batch_output = []
        else:
            # hook
            self.trainer._call_callback_hooks("on_batch_start")

            # TODO: Update this in v1.7 (deprecation: #9816)
            model_fx = self.trainer.lightning_module.on_train_batch_start
            extra_kwargs = (
                {"dataloader_idx": 0}
                if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
                else {}
            )

            # hook
            self.trainer._call_callback_hooks("on_train_batch_start", batch, batch_idx, **extra_kwargs)
            response = self.trainer._call_lightning_module_hook(
                "on_train_batch_start", batch, batch_idx, **extra_kwargs
            )
            self.trainer._call_strategy_hook("on_train_batch_start", batch, batch_idx, **extra_kwargs)
            if response == -1:
                self.batch_progress.increment_processed()
                raise StopIteration

            self.batch_progress.increment_started()

            with self.trainer.profiler.profile("run_training_batch"):
                batch_output = self.batch_loop.run(batch, batch_idx)

        self.batch_progress.increment_processed()

        # update non-plateau LR schedulers
        # update epoch-interval ones only when we are at the end of training epoch
        self.update_lr_schedulers("step", update_plateau_schedulers=False)
        if self._num_ready_batches_reached():
            self.update_lr_schedulers("epoch", update_plateau_schedulers=False)

        batch_end_outputs = self._prepare_outputs_training_batch_end(
            batch_output,
            lightning_module=self.trainer.lightning_module,
            num_optimizers=len(self.trainer.optimizers),
        )

        # TODO: Update this in v1.7 (deprecation: #9816)
        model_fx = self.trainer.lightning_module.on_train_batch_end
        extra_kwargs = (
            {"dataloader_idx": 0}
            if callable(model_fx) and is_param_in_hook_signature(model_fx, "dataloader_idx", explicit=True)
            else {}
        )
        self.trainer._call_callback_hooks("on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs)
        self.trainer._call_lightning_module_hook(
            "on_train_batch_end", batch_end_outputs, batch, batch_idx, **extra_kwargs
        )
        self.trainer._call_callback_hooks("on_batch_end")
        self.trainer._logger_connector.on_batch_end()

        self.batch_progress.increment_completed()

        if is_overridden("training_epoch_end", self.trainer.lightning_module):
            self._outputs.append(batch_output)

        # -----------------------------------------
        # SAVE METRICS TO LOGGERS AND PROGRESS_BAR
        # -----------------------------------------
        self.trainer._logger_connector.update_train_step_metrics()

    def on_advance_end(self) -> None:
        # -----------------------------------------
        # VALIDATE IF NEEDED
        # -----------------------------------------
        should_check_val = self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch)
        if should_check_val:
            self.trainer.validating = True
            self._run_validation()
            self.trainer.training = True

        # update plateau LR scheduler after metrics are logged
        self.update_lr_schedulers("step", update_plateau_schedulers=True)

        if not self._should_accumulate():
            # this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers
            self._batches_that_stepped += 1
        # this will save based on the `batches_that_stepped` value
        self._save_loggers_on_train_batch_end()

        # if training finished, defer exit to the parent. this assumes there will be enough time in between
        # which might not be the case depending on what's in the `*_epoch_end` hooks
        if not self._is_training_done:
            # if fault tolerant is enabled and process has been notified, exit.
            self.trainer._exit_gracefully_on_signal()
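
    # Illustrative annotation (not part of the original source): `_batches_that_stepped` only
    # advances on batches where `_should_accumulate()` is False, so with gradient accumulation it
    # counts optimizer-stepping batches rather than every batch. A sketch of the underlying
    # arithmetic from `_accumulated_batches_reached`, assuming `accumulate_grad_batches=2`:
    #
    #     >>> accumulate_grad_batches = 2
    #     >>> [ready % accumulate_grad_batches == 0 for ready in (1, 2, 3, 4)]  # step on True
    #     [False, True, False, True]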

    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()

        if (
            self.trainer is not None
            and self.trainer.state._fault_tolerant_mode.is_enabled
            and self.trainer.train_dataloader is not None
            and not self._num_completed_batches_reached()  # did not finish
            # TODO: fault-tolerance requires a minimum number of batches so probably should be > 0
            and self.batch_progress.current.ready  # did start
        ):
            loader: CombinedLoader = self.trainer.train_dataloader
            state = loader.state_dict(has_completed=self._has_completed())
            if state:
                state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)

        return state_dict

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
        self._dataloader_state_dict = state_dict.get("dataloader_state_dict")

    def _run_validation(self) -> None:
        # reload dataloaders
        self.val_loop._reload_evaluation_dataloaders()

        with torch.no_grad():
            self.val_loop.run()

    def _accumulated_batches_reached(self) -> bool:
        """Determine if accumulation will be finished by the end of the current batch."""
        return self.batch_progress.current.ready % self.trainer.accumulate_grad_batches == 0

    def _num_ready_batches_reached(self) -> bool:
        """Checks if we are in the last batch or if there are more batches to follow."""
        epoch_finished_on_ready = self.batch_progress.current.ready == self.trainer.num_training_batches
        return epoch_finished_on_ready or self.batch_progress.is_last_batch

    def _num_completed_batches_reached(self) -> bool:
        epoch_finished_on_completed = self.batch_progress.current.completed == self.trainer.num_training_batches
        dataloader_consumed_successfully = self.batch_progress.is_last_batch and self._has_completed()
        return epoch_finished_on_completed or dataloader_consumed_successfully

    def _has_completed(self) -> bool:
        return self.batch_progress.current.ready == self.batch_progress.current.completed

    def _should_accumulate(self) -> bool:
        """Checks if the optimizer step should be performed or gradients should be accumulated for the current
        step."""
        accumulation_done = self._accumulated_batches_reached()
        # Lightning steps on the final batch
        is_final_batch = self._num_ready_batches_reached()
        # but the strategy might not
        strategy_accumulates_on_final_batch = self.trainer.strategy.handles_gradient_accumulation or not is_final_batch
        return not accumulation_done and strategy_accumulates_on_final_batch

    @staticmethod
    def _prepare_outputs_training_batch_end(
        batch_output: _BATCH_OUTPUTS_TYPE,
        lightning_module: "pl.LightningModule",
        num_optimizers: int,
    ) -> Union[List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
        """Processes the outputs from the batch loop into the format passed to the ``on_train_batch_end`` hook."""
        if not batch_output:
            return []

        # convert optimizer dicts to list
        if lightning_module.automatic_optimization:
            batch_output = apply_to_collection(
                batch_output, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
            )

        array = np.array(batch_output, dtype=object)
        # TODO: remove in v1.8
        if (
            num_optimizers > 1
            and lightning_module.truncated_bptt_steps > 0
            and not _v1_8_output_format(lightning_module.on_train_batch_end)
        ):
            rank_zero_deprecation(
                "You are training with multiple optimizers AND truncated backpropagation through time enabled."
                " The current format of the `on_train_batch_end(outputs, ...)` is a 2d list with sizes"
                " (n_optimizers, tbptt_steps), however, this has been deprecated and will change in version v1.8 to"
                " (tbptt_steps, n_optimizers). You can update your code by adding the following parameter to your"
                " hook signature: `on_train_batch_end(outputs, ..., new_format=True)`."
            )
            # (tbptt_steps, n_opt) -> (n_opt, tbptt_steps)
            if array.ndim == 1:
                array = np.expand_dims(array, 1)
            array = array.transpose((1, 0))
        # squeeze all single-element dimensions
        array = array.squeeze()
        array = array.tolist()
        array = _recursive_unpad(array)
        return array

    @staticmethod
    def _prepare_outputs_training_epoch_end(
        batch_outputs: _OUTPUTS_TYPE,
        lightning_module: "pl.LightningModule",
        num_optimizers: int,
    ) -> Union[List[List[List[Dict[str, Any]]]], List[List[Dict[str, Any]]], List[Dict[str, Any]]]:
        """Processes the outputs from the batch loop into the format passed to the ``training_epoch_end`` hook."""
        # `batch_outputs` (plural) is the same as `epoch_end_output` (singular)
        if not batch_outputs:
            return []

        # convert optimizer dicts to list
        if lightning_module.automatic_optimization:
            batch_outputs = apply_to_collection(
                batch_outputs, dtype=dict, function=_convert_optim_dict, num_optimizers=num_optimizers
            )

        array = _recursive_pad(batch_outputs)
        # TODO: remove in v1.8
        if (
            num_optimizers > 1
            and lightning_module.truncated_bptt_steps > 0
            and not _v1_8_output_format(lightning_module.on_train_epoch_end)
        ):
            rank_zero_deprecation(
                "You are training with multiple optimizers AND truncated backpropagation through time enabled."
                " The current format of the `training_epoch_end(outputs)` is a 3d list with sizes"
                " (n_optimizers, n_batches, tbptt_steps), however, this has been deprecated and will change in version"
                " v1.8 to (n_batches, tbptt_steps, n_optimizers). You can update your code by adding the following"
                " parameter to your hook signature: `training_epoch_end(outputs, new_format=True)`."
            )
            # (n_batches, tbptt_steps, n_opt) -> (n_opt, n_batches, tbptt_steps)
            if array.ndim == 2:
                array = np.expand_dims(array, 2)
            array = array.transpose((2, 0, 1))
        # squeeze all single-element dimensions
        array = array.squeeze()
        array = array.tolist()
        array = _recursive_unpad(array)
        # in case we squeezed from 1-element array to a 0-dim array
        array = array if isinstance(array, list) else [array]
        # remove residual empty lists
        array = [item for item in array if not isinstance(item, list) or len(item)]
        return array
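
    # Illustrative annotation (not part of the original source): with automatic optimization and
    # two optimizers, the batch loop yields a dict keyed by optimizer index, which the helpers
    # above flatten into a list before calling `on_train_batch_end` / `training_epoch_end`:
    #
    #     >>> _convert_optim_dict({0: {"loss": 0.1}, 1: {"loss": 0.2}}, num_optimizers=2)
    #     [{'loss': 0.1}, {'loss': 0.2}]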

    def update_lr_schedulers(self, interval: str, update_plateau_schedulers: bool) -> None:
        """updates the lr schedulers based on the given interval."""
        if interval == "step" and self._should_accumulate():
            return
        active_optimizers = _get_active_optimizers(
            self.trainer.optimizers, self.trainer.optimizer_frequencies, self.total_batch_idx
        )
        self._update_learning_rates(
            interval=interval,
            update_plateau_schedulers=update_plateau_schedulers,
            opt_indices=[opt_idx for opt_idx, _ in active_optimizers],
        )

    def _update_learning_rates(
        self, interval: str, update_plateau_schedulers: bool, opt_indices: Optional[List[int]] = None
    ) -> None:
        """Update learning rates.

        Args:
            interval: either 'epoch' or 'step'.
            update_plateau_schedulers: control whether ``ReduceLROnPlateau`` or non-plateau schedulers get updated.
                This is used so non-plateau schedulers can be updated before running validation. Checkpoints are
                commonly saved during validation, however, on-plateau schedulers might monitor a validation metric
                so they have to be updated separately.
            opt_indices: indices of the optimizers to update.
        """
        if not self.trainer.lr_scheduler_configs or not self.trainer.lightning_module.automatic_optimization:
            return

        if opt_indices is None:
            opt_indices = []

        for config in self.trainer.lr_scheduler_configs:
            if config.opt_idx not in opt_indices:
                continue

            if update_plateau_schedulers ^ config.reduce_on_plateau:
                continue

            current_idx = self.batch_idx if interval == "step" else self.trainer.current_epoch
            current_idx += 1  # account for both batch and epoch starts from 0

            # Take step if call to update_learning_rates matches the interval key and
            # the current step modulo the schedulers frequency is zero
            if config.interval == interval and current_idx % config.frequency == 0:
                monitor_val = None
                if config.reduce_on_plateau:
                    # If instance of ReduceLROnPlateau, we need a monitor
                    monitor_key = config.monitor
                    monitor_val = self._get_monitor_value(monitor_key)
                    if monitor_val is None:
                        if config.strict:
                            avail_metrics = list(self.trainer.callback_metrics)
                            raise MisconfigurationException(
                                f"ReduceLROnPlateau conditioned on metric {monitor_key}"
                                f" which is not available. Available metrics are: {avail_metrics}."
                                " Condition can be set using `monitor` key in lr scheduler dict"
                            )
                        rank_zero_warn(
                            f"ReduceLROnPlateau conditioned on metric {monitor_key}"
                            " which is not available but strict is set to `False`."
                            " Skipping learning rate update.",
                            category=RuntimeWarning,
                        )
                        continue

                self.scheduler_progress.increment_ready()

                # update LR
                self.trainer._call_lightning_module_hook(
                    "lr_scheduler_step",
                    config.scheduler,
                    config.opt_idx,
                    monitor_val,
                )

                self.scheduler_progress.increment_completed()

    def _get_monitor_value(self, key: str) -> Any:
        # this is a separate method to aid in testing
        return self.trainer.callback_metrics.get(key)

    def _should_check_val_epoch(self):
        return (
            self.trainer.enable_validation
            and (self.trainer.current_epoch + 1) % self.trainer.check_val_every_n_epoch == 0
        )

    def _should_check_val_fx(self, batch_idx: int, is_last_batch: bool) -> bool:
        """Decide if we should run validation."""
        if not self._should_check_val_epoch():
            return False

        # val_check_batch is inf for iterable datasets with no length defined
        is_infinite_dataset = self.trainer.val_check_batch == float("inf")
        if is_last_batch and is_infinite_dataset:
            return True

        if self.trainer.should_stop:
            return True

        # TODO(@awaelchli): let training/eval loop handle logic around limit_*_batches and val_check_batch
        is_val_check_batch = is_last_batch
        if isinstance(self.trainer.limit_train_batches, int) and is_infinite_dataset:
            is_val_check_batch = (batch_idx + 1) % self.trainer.limit_train_batches == 0
        elif self.trainer.val_check_batch != float("inf"):
            is_val_check_batch = (batch_idx + 1) % self.trainer.val_check_batch == 0
        return is_val_check_batch

    def _save_loggers_on_train_batch_end(self) -> None:
        """Flushes loggers to disk."""
        # this assumes that `batches_that_stepped` was increased before
        should_flush = self._batches_that_stepped % self.trainer.flush_logs_every_n_steps == 0
        if should_flush or self.trainer.should_stop:
            for logger in self.trainer.loggers:
                logger.save()

    def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
        if self._dataloader_state_dict:
            data_fetcher.dataloader.load_state_dict(self._dataloader_state_dict)
            self._dataloader_state_dict = None
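
# Illustrative sketch (not part of the original source): how a custom batch or validation loop
# could be attached through `connect` before training. The function and subclass names below are
# hypothetical and only reuse the default behaviour; this is a usage sketch, not an official recipe.
def _example_connect_custom_loops() -> None:
    class MyBatchLoop(TrainingBatchLoop):
        """Hypothetical subclass; behaves exactly like the default batch loop."""

    epoch_loop = TrainingEpochLoop(min_steps=None, max_steps=1000)
    epoch_loop.connect(batch_loop=MyBatchLoop())
    # a custom validation loop can be attached the same way:
    epoch_loop.connect(val_loop=loops.EvaluationLoop(verbose=False))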

def _convert_optim_dict(outs: Dict[int, Dict[str, Any]], num_optimizers: int) -> List[Optional[Dict[str, Any]]]:
    """Converts an optimizer dict to a list in which the key of the dict determines the position of the element.

    Example::
        >>> _convert_optim_dict({0: {"loss": 0.0}, 2: {"loss": 0.2}}, num_optimizers=3)
        [{'loss': 0.0}, None, {'loss': 0.2}]
    """
    return [outs[opt_idx] if opt_idx in outs else None for opt_idx in range(num_optimizers)]


@overload
def _recursive_unpad(nested: Any, value: Optional[Any] = None) -> Any:
    ...


@overload
def _recursive_unpad(nested: List[Any], value: Optional[Any] = None) -> List[Any]:
    ...


def _recursive_unpad(nested: Union[Any, List[Any]], value: Optional[Any] = None) -> Union[Any, List[Any]]:
    """Removes the given pad value from the nested list.

    Not strictly the reverse operation of :func:`_recursive_pad` because it removes the padding element everywhere,
    not just from the end of a list.

    Example::
        >>> _recursive_unpad([[[0, 1, 0]], [2], [0, 0]], value=0)
        [[[1]], [2], []]
    """
    if not isinstance(nested, list):
        return nested
    return [_recursive_unpad(item, value) for item in nested if item != value]


def _recursive_pad(nested: List[Any], fill_value: Optional[Any] = None) -> np.array:
    """Pads a jagged nested list of lists with the given value such that a proper multi-dimensional array can be
    formed with rectangular shape. The padding appends to the incomplete lists.

    Example::
        >>> _recursive_pad([[], [1], [2, 3], [4]], fill_value=0)  # doctest: +NORMALIZE_WHITESPACE
        array([[0, 0], [1, 0], [2, 3], [4, 0]], dtype=object)
    """
    # code adapted from stackexchange:
    # https://codereview.stackexchange.com/questions/222623/pad-a-ragged-multidimensional-array-to-rectangular-shape
    dimensions = _get_max_shape(nested)
    result = np.full(dimensions, fill_value, dtype=object)
    for index, value in _iterate_nested_array(nested):
        result[index] = value
    return result


def _get_dimensions(array: List[Any], level: int = 0) -> Generator:
    yield level, len(array)
    if all(isinstance(row, list) for row in array):
        for row in array:
            yield from _get_dimensions(row, level + 1)


def _get_max_shape(array: List[Any]) -> List[int]:
    """Calculates the max size in each dimension of a jagged (non-rectangular) nested list of lists.

    Example::
        >>> _get_max_shape([[], [[1], [2]], []])
        [3, 2, 1]
    """
    dimensions = defaultdict(int)
    for level, length in _get_dimensions(array):
        dimensions[level] = max(dimensions[level], length)
    return [value for _, value in sorted(dimensions.items())]


def _iterate_nested_array(array: List[Any], index: Tuple = ()) -> Generator:
    if all(isinstance(item, list) for item in array):
        for idx, row in enumerate(array):
            yield from _iterate_nested_array(row, (*index, idx))
    else:  # final level
        yield (*index, slice(len(array))), array
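
# Illustrative sketch (not part of the original source): round-tripping a ragged nested list
# through `_recursive_pad` and `_recursive_unpad`. The values mirror the doctests above; the
# helper name is hypothetical.
def _example_pad_unpad_roundtrip() -> None:
    ragged = [[], [1], [2, 3], [4]]
    padded = _recursive_pad(ragged, fill_value=0)  # rectangular object array of shape (4, 2)
    restored = _recursive_unpad(padded.tolist(), value=0)  # padding value removed everywhere
    assert restored == ragged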