Source code for pytorch_lightning.loops.epoch.evaluation_epoch_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Dict, Optional

from deprecate import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
    _collect_states_on_rank_zero_over_collection,
    _reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT


class EvaluationEpochLoop(Loop):
    """This is the loop performing the evaluation.

    It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's
    current state).
    """

    def __init__(self) -> None:
        super().__init__()
        self.batch_progress = BatchProgress()

        self._outputs: EPOCH_OUTPUT = []
        self._dl_max_batches = 0
        self._data_fetcher: Optional[AbstractDataFetcher] = None
        self._dataloader_state_dict: Dict[str, Any] = {}
        self._dl_batch_idx = [0]

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches
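
    # Illustrative sketch, not part of the original file: the base ``Loop.run``
    # drives this loop through the ``done`` property and the hooks defined below,
    # roughly as follows (error handling omitted):
    #
    #     self.reset()
    #     self.on_run_start(*args, **kwargs)
    #     while not self.done:
    #         self.advance(*args, **kwargs)
    #     output = self.on_run_end()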

    def reset(self) -> None:
        """Resets the loop's internal state."""
        self._dl_max_batches = 0
        self._data_fetcher = None
        self._outputs = []

        if not self.restarting:
            self.batch_progress.reset_on_run()
        else:
            self.batch_progress.reset_on_restart()

        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
        # need to reset the current state when the loop has finished running
        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
            self.batch_progress.reset_on_run()

    def on_run_start(  # type: ignore[override]
        self, data_fetcher: AbstractDataFetcher, dl_max_batches: int, kwargs: OrderedDict
    ) -> None:
        """Adds the passed arguments to the loop's state if necessary.

        Args:
            data_fetcher: the current data_fetcher wrapping the dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            kwargs: the kwargs passed down to the hooks.
        """
        void(kwargs)

        self._dl_max_batches = dl_max_batches
        self._reload_dataloader_state_dict(data_fetcher)
        # creates the iterator inside the fetcher but returns `self`
        self._data_fetcher = iter(data_fetcher)
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready

    def advance(  # type: ignore[override]
        self,
        data_fetcher: AbstractDataFetcher,
        dl_max_batches: int,
        kwargs: OrderedDict,
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            data_fetcher: iterator over the dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            kwargs: the kwargs passed down to the hooks.

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches)

        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
            batch_idx = self.batch_progress.current.ready
            batch = next(data_fetcher)
        else:
            batch_idx, batch = next(data_fetcher)
        self.batch_progress.is_last_batch = data_fetcher.done

        # configure step_kwargs
        kwargs = self._build_kwargs(kwargs, batch, batch_idx)

        self.batch_progress.increment_ready()

        # hook
        self._on_evaluation_batch_start(**kwargs)

        self.batch_progress.increment_started()

        # lightning module methods
        output = self._evaluation_step(**kwargs)
        output = self._evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # track loss history
        self._on_evaluation_batch_end(output, **kwargs)

        self.batch_progress.increment_completed()

        # log batch metrics
        if not self.trainer.sanity_checking:
            dataloader_idx = kwargs.get("dataloader_idx", 0)
            self.trainer._logger_connector.update_eval_step_metrics(self._dl_batch_idx[dataloader_idx])
            self._dl_batch_idx[dataloader_idx] += 1

        # track epoch level outputs
        if self._should_track_batch_outputs_for_epoch_end() and output is not None:
            self._outputs.append(output)

        if self.trainer.move_metrics_to_cpu:
            # the evaluation step output is not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        if not self.batch_progress.is_last_batch:
            # if fault tolerant is enabled and process has been notified, exit.
            self.trainer._exit_gracefully_on_signal()
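
    # Illustrative sketch, not part of the original file: for each batch, ``advance``
    # above runs the user-facing hooks in this order (validation names shown; the
    # ``test_*`` variants are used when ``trainer.testing`` is set):
    #
    #     on_validation_batch_start(batch, batch_idx, dataloader_idx)
    #     output = validation_step(batch, batch_idx, dataloader_idx)
    #     output = validation_step_end(output)
    #     on_validation_batch_end(output, batch, batch_idx, dataloader_idx)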

    def on_run_end(self) -> EPOCH_OUTPUT:
        """Returns the outputs of the whole run."""
        outputs, self._outputs = self._outputs, []  # free memory
        self._data_fetcher = None
        return outputs

    def teardown(self) -> None:
        # in case the model changes
        self._should_track_batch_outputs_for_epoch_end.cache_clear()

    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()

        if (
            self.trainer is not None
            and self.trainer.state._fault_tolerant_mode.is_enabled
            and self._data_fetcher is not None
            and not self._num_completed_batches_reached()  # did not finish
            and self.batch_progress.current.ready  # did start
        ):
            state = CombinedLoader._state_dict_fn(self._data_fetcher.dataloader_iter, self._has_completed())
            if state:
                state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)

        return state_dict

    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
        # dataset states are collected across all ranks
        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
        if not _fault_tolerant_training() or not dataloader_state_dict:
            return
        self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
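
    # Illustrative sketch, not part of the original file: under fault-tolerant
    # training the dataloader state round-trips roughly like this:
    #
    #     state_dict = loop.on_save_checkpoint()
    #     # -> {..., "dataloader_state_dict": <per-rank states collected on rank zero>}
    #     loop.on_load_checkpoint(state_dict)
    #     # caches state_dict["dataloader_state_dict"][trainer.global_rank] until
    #     # `_reload_dataloader_state_dict` restores it into the new DataLoader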

    def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
        if self.trainer.sanity_checking or not self._dataloader_state_dict:
            return
        dataloader = data_fetcher.dataloader
        if isinstance(dataloader, CombinedLoader):
            raise MisconfigurationException(
                "Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an issue"
                " in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
            )
        assert isinstance(dataloader, DataLoader)
        _reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
        self._dataloader_state_dict = {}

    def _num_completed_batches_reached(self) -> bool:
        epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches
        dataloader_consumed_successfully = self.batch_progress.is_last_batch and self._has_completed()
        return epoch_finished_on_completed or dataloader_consumed_successfully

    def _has_completed(self) -> bool:
        return self.batch_progress.current.ready == self.batch_progress.current.completed

    def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """The evaluation step (validation_step or test_step depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        if self.trainer.testing:
            output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
        else:
            output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
        return output

    def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the `{validation/test}_step_end` hook."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
        strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
        output = strategy_output if model_output is None else model_output
        return output

    def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer._logger_connector.on_batch_start(**kwargs)

        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
        hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
        self.trainer._call_callback_hooks(hook_name, *kwargs.values())
        self.trainer._call_lightning_module_hook(hook_name, *kwargs.values())

    def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
        self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())

        self.trainer._logger_connector.on_batch_end()

    def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
        """Helper function to build the arguments for the current step.

        Args:
            kwargs: The kwargs passed down to the hooks.
            batch: The current batch to run through the step.

        Returns:
            The kwargs passed down to the hooks.
        """
        kwargs.update({"batch": batch, "batch_idx": batch_idx})
        kwargs.move_to_end("batch_idx", last=False)
        kwargs.move_to_end("batch", last=False)
        return kwargs

    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        """Whether the batch outputs should be stored for later usage."""
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden("test_epoch_end", model)
        return is_overridden("validation_epoch_end", model)

    def _reset_dl_batch_idx(self, num_dataloaders: int) -> None:
        self._dl_batch_idx = [0] * num_dataloaders
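
# Illustrative example, not part of the original file: ``_build_kwargs`` prepends
# ``batch`` and ``batch_idx`` via ``move_to_end(..., last=False)`` so that the
# positional hook calls above receive them first. Assuming an incoming
# ``OrderedDict`` that already carries a ``dataloader_idx``:
#
#     kwargs = OrderedDict(dataloader_idx=1)
#     kwargs = loop._build_kwargs(kwargs, batch=batch, batch_idx=3)
#     # -> OrderedDict([("batch", batch), ("batch_idx", 3), ("dataloader_idx", 1)])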