# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from functools import partial
from typing import Optional, Type

import pytorch_lightning as pl
from pytorch_lightning.accelerators import CUDAAccelerator
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.epoch import TrainingEpochLoop
from pytorch_lightning.loops.epoch.training_epoch_loop import _OUTPUTS_TYPE as _EPOCH_OUTPUTS_TYPE
from pytorch_lightning.loops.utilities import _is_max_limit_reached, _set_sampler_epoch
from pytorch_lightning.trainer.connectors.logger_connector.result import _ResultCollection
from pytorch_lightning.trainer.progress import Progress
from pytorch_lightning.trainer.supporters import TensorRunningAccum
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    AbstractDataFetcher,
    DataFetcher,
    DataLoaderIterDataFetcher,
    InterBatchParallelDataFetcher,
)
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_debug, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature

log = logging.getLogger(__name__)
class FitLoop(Loop[None]):
    """This Loop iterates over the epochs to run the training.

    Args:
        min_epochs: The minimum number of epochs
        max_epochs: The maximum number of epochs, can be set -1 to turn this limit off
    """

    def __init__(
        self,
        min_epochs: int = 0,
        max_epochs: Optional[int] = None,
    ) -> None:
        super().__init__()
        if isinstance(max_epochs, int) and max_epochs < -1:
            # Allow max_epochs to be zero, since this will be handled by fit_loop.done
            raise MisconfigurationException(
                f"`max_epochs` must be a non-negative integer or -1. You passed in {max_epochs}."
            )

        self.max_epochs = max_epochs
        self.min_epochs = min_epochs
        self.epoch_loop = TrainingEpochLoop()
        self.epoch_progress = Progress()
        self._is_fresh_start_epoch: bool = True

        self._outputs: _EPOCH_OUTPUTS_TYPE = []
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def total_batch_idx(self) -> int:
        """Returns the current batch index (across epochs)."""
        return self.epoch_loop.total_batch_idx

    @property
    def batch_idx(self) -> int:
        """Returns the current batch index (within this epoch)."""
        return self.epoch_loop.batch_idx

    @property
    def split_idx(self) -> int:
        """Returns the index of the current batch split (within the current batch) for bptt."""
        return self.epoch_loop.batch_loop.split_idx

    @property
    def min_steps(self) -> Optional[int]:
        # TODO(@justusschock): Why aren't we using the attribute in this class?
        """Returns the minimum number of steps to run."""
        return self.epoch_loop.min_steps

    @min_steps.setter
    def min_steps(self, value: Optional[int]) -> None:
        """Sets the minimum number of steps (forwards to epoch_loop)."""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        self.epoch_loop.min_steps = value

    @property
    def max_steps(self) -> int:
        """Returns the maximum number of steps to run."""
        return self.epoch_loop.max_steps

    @max_steps.setter
    def max_steps(self, value: int) -> None:
        """Sets the maximum number of steps (forwards to epoch_loop)."""
        # TODO(@awaelchli): This setter is required by debugging connector (fast dev run), should be avoided
        if value < -1:
            raise MisconfigurationException(
                f"`max_steps` must be a non-negative integer or -1 (infinite steps). You passed in {value}."
            )
        self.epoch_loop.max_steps = value

    @property
    def running_loss(self) -> TensorRunningAccum:
        """Returns the running loss."""
        return self.epoch_loop.batch_loop.running_loss

    @Loop.restarting.setter
    def restarting(self, restarting: bool) -> None:
        # if the last epoch completely finished, we are not actually restarting
        values = self.epoch_progress.current.ready, self.epoch_progress.current.started
        epoch_unfinished = any(v != self.epoch_progress.current.processed for v in values)
        restarting = restarting and epoch_unfinished or self._iteration_based_training()
        Loop.restarting.fset(self, restarting)  # call the parent setter

    @property
    def prefetch_batches(self) -> int:
        is_unsized = self.trainer.num_training_batches == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    @property
    def _skip_backward(self) -> bool:
        """Determines whether the loop will skip backward during automatic optimization."""
        return self.epoch_loop.batch_loop.optimizer_loop._skip_backward

    @_skip_backward.setter
    def _skip_backward(self, value: bool) -> None:
        """Determines whether the loop will skip backward during automatic optimization."""
        self.epoch_loop.batch_loop.optimizer_loop._skip_backward = value

    @property
    def _results(self) -> _ResultCollection:
        if self.trainer.training:
            return self.epoch_loop._results
        if self.trainer.validating:
            return self.epoch_loop.val_loop._results
        raise RuntimeError("`FitLoop._results` property isn't defined. Accessed outside of scope")

    @property
    def done(self) -> bool:
        """Evaluates when to leave the loop."""
        if self.trainer.num_training_batches == 0:
            rank_zero_info("`Trainer.fit` stopped: No training batches.")
            return True

        # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
        stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
        if stop_steps:
            rank_zero_info(f"`Trainer.fit` stopped: `max_steps={self.max_steps!r}` reached.")
            return True

        # `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
        # we use it here because the checkpoint data won't have `completed` increased yet
        assert isinstance(self.max_epochs, int)
        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
        if stop_epochs:
            # in case they are not equal, override so `trainer.current_epoch` has the expected value
            self.epoch_progress.current.completed = self.epoch_progress.current.processed
            rank_zero_info(f"`Trainer.fit` stopped: `max_epochs={self.max_epochs!r}` reached.")
            return True

        if self.trainer.should_stop:
            # early stopping
            met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
            met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
            if met_min_epochs and met_min_steps:
                self.trainer.should_stop = True
                rank_zero_debug("`Trainer.fit` stopped: `trainer.should_stop` was set.")
                return True
            else:
                rank_zero_info(
                    f"Trainer was signaled to stop but the required `min_epochs={self.min_epochs!r}` or"
                    f" `min_steps={self.min_steps!r}` has not been met. Training will continue..."
                )
                self.trainer.should_stop = False

        return False

    @property
    def skip(self) -> bool:
        """Whether we should skip the training and immediately return from the call to :meth:`run`."""
        # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
        # until `on_run_start`, we use `limit_train_batches` instead
        return self.done or self.trainer.limit_train_batches == 0
    def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
        """Connects a training epoch loop to this fit loop."""
        self.epoch_loop = epoch_loop
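    # Sketch of loop customization via `connect` (assumes a user-defined `TrainingEpochLoop`
    # subclass named `MyEpochLoop`; the exact customization workflow may differ by version):
    #
    #     trainer = pl.Trainer()
    #     trainer.fit_loop.connect(epoch_loop=MyEpochLoop())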
    def reset(self) -> None:
        """Resets the internal state of this loop."""
        if self.restarting:
            self.epoch_progress.reset_on_restart()
    def on_run_start(self) -> None:  # type: ignore[override]
        """Calls the ``on_train_start`` hook."""
        # update the current_epoch in-case of checkpoint reload
        if not self._iteration_based_training():
            self.epoch_progress.current.completed = self.epoch_progress.current.processed

        self.trainer.reset_train_dataloader(self.trainer.lightning_module)
        # reload the evaluation dataloaders too for proper display in the progress bar
        if self.epoch_loop._should_check_val_epoch():
            self.epoch_loop.val_loop._reload_evaluation_dataloaders()

        data_fetcher_cls = _select_data_fetcher(self.trainer)
        self._data_fetcher = data_fetcher_cls(prefetch_batches=self.prefetch_batches)

        self._is_fresh_start_epoch = True
        self._results.to(device=self.trainer.lightning_module.device)

        self.trainer._call_callback_hooks("on_train_start")
        self.trainer._call_lightning_module_hook("on_train_start")
        self.trainer._call_strategy_hook("on_train_start")
    def on_advance_start(self) -> None:  # type: ignore[override]
        """Prepares the dataloader for training and calls the hooks ``on_epoch_start`` and
        ``on_train_epoch_start``"""
        model = self.trainer.lightning_module

        # reset train dataloader
        if not self._is_fresh_start_epoch and self.trainer._data_connector._should_reload_train_dl:
            log.detail(f"{self.__class__.__name__}: resetting train dataloader")
            self.trainer.reset_train_dataloader(model)
        self._is_fresh_start_epoch = False

        # reset outputs here instead of in `reset` as they are not accumulated between epochs
        self._outputs = []

        if self.trainer.train_dataloader is not None:
            _set_sampler_epoch(self.trainer.train_dataloader, self.epoch_progress.current.processed)

        # changing gradient according accumulation_scheduler
        self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)

        # stores accumulated grad fractions per batch
        self.epoch_loop.batch_loop.accumulated_loss.reset(window_length=self.trainer.accumulate_grad_batches)

        self.epoch_progress.increment_ready()

        self.trainer._logger_connector.on_epoch_start()

        self.trainer._call_callback_hooks("on_epoch_start")
        self.trainer._call_lightning_module_hook("on_epoch_start")

        self.trainer._call_callback_hooks("on_train_epoch_start")
        self.trainer._call_lightning_module_hook("on_train_epoch_start")

        self.epoch_progress.increment_started()
    def advance(self) -> None:  # type: ignore[override]
        """Runs one whole epoch."""
        log.detail(f"{self.__class__.__name__}: advancing loop")
        assert self.trainer.train_dataloader is not None
        dataloader = self.trainer.train_dataloader

        assert self._data_fetcher is not None
        self._data_fetcher.setup(
            dataloader, batch_to_device=partial(self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=0)
        )
        with self.trainer.profiler.profile("run_training_epoch"):
            self._outputs = self.epoch_loop.run(self._data_fetcher)
    def on_advance_end(self) -> None:
        # inform logger the batch loop has finished
        self.trainer._logger_connector.epoch_end_reached()

        # get the model and call model.training_epoch_end
        model = self.trainer.lightning_module
        if is_overridden("training_epoch_end", model) and self._outputs:
            epoch_end_outputs = self.epoch_loop._prepare_outputs_training_epoch_end(
                self._outputs,
                lightning_module=model,
                num_optimizers=len(self.trainer.optimizers),
            )
            # run lightning module hook training_epoch_end
            # refresh the result for custom logging at the epoch level
            epoch_end_outputs = self.trainer._call_lightning_module_hook("training_epoch_end", epoch_end_outputs)
            if epoch_end_outputs is not None:
                raise MisconfigurationException(
                    "`training_epoch_end` expects a return of None. "
                    "HINT: remove the return statement in `training_epoch_end`."
                )
        # free memory
        self._outputs = []

        self.epoch_progress.increment_processed()

        # call train epoch end hooks
        self.trainer._call_callback_hooks("on_train_epoch_end")
        self.trainer._call_lightning_module_hook("on_train_epoch_end")

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")

        self.trainer._logger_connector.on_epoch_end()

        if self.epoch_loop._num_ready_batches_reached():
            self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=True)

        # we manually decrease here because loggers expect that the same step is used when logging epoch-end metrics
        # even when the batch loop has finished
        self.epoch_loop._batches_that_stepped -= 1
        # log epoch metrics
        self.trainer._logger_connector.update_train_epoch_metrics()
        self.epoch_loop._batches_that_stepped += 1

        self.epoch_progress.increment_completed()

        # if fault tolerant is enabled and process has been notified, exit.
        self.trainer._exit_gracefully_on_signal()
    def on_run_end(self) -> None:
        """Calls the ``on_train_end`` hook."""
        log.detail(f"{self.__class__.__name__}: train run ended")

        # hook
        self.trainer._call_callback_hooks("on_train_end")
        self.trainer._call_lightning_module_hook("on_train_end")
        self.trainer._call_strategy_hook("on_train_end")
    def _should_accumulate(self) -> bool:
        """Whether the gradients should be accumulated."""
        return self.epoch_loop._should_accumulate()

    def _iteration_based_training(self) -> bool:
        return self.trainer.max_steps != -1
def _select_data_fetcher(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    training_step_fx = getattr(trainer.lightning_module, "training_step")
    if is_param_in_hook_signature(training_step_fx, "dataloader_iter", explicit=True):
        rank_zero_warn(
            "Found `dataloader_iter` argument in the `training_step`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher
    elif os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        if not isinstance(trainer.accelerator, CUDAAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher
    return DataFetcher
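# Sketch of the experimental `dataloader_iter` signature that makes `_select_data_fetcher` return
# `DataLoaderIterDataFetcher` (hypothetical user module; behavior is subject to change as warned above):
#
#     class MyModel(pl.LightningModule):
#         def training_step(self, dataloader_iter):
#             batch = next(dataloader_iter)
#             ...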