Source code for pytorch_lightning.loops.epoch.evaluation_epoch_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from functools import lru_cache
from typing import Any, Dict, Optional

from deprecate import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
    _collect_states_on_rank_zero_over_collection,
    _reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataLoaderIterDataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
class EvaluationEpochLoop(Loop):
    """This is the loop performing the evaluation.

    It mainly loops over the given dataloader and runs the validation or test step (depending on the trainer's
    current state).
    """

    def __init__(self) -> None:
        super().__init__()

        self.batch_progress = BatchProgress()

        self._outputs: EPOCH_OUTPUT = []
        self._dl_max_batches = 0
        self._data_fetcher: Optional[AbstractDataFetcher] = None
        self._dataloader_state_dict: Dict[str, Any] = {}

    @property
    def done(self) -> bool:
        """Returns ``True`` if the current iteration count reaches the number of dataloader batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches
    def reset(self) -> None:
        """Resets the loop's internal state."""
        self._dl_max_batches = 0
        self._data_fetcher = None
        self._outputs = []

        if not self.restarting:
            self.batch_progress.reset_on_run()
        else:
            self.batch_progress.reset_on_restart()

        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
        # need to reset the current state when the loop has finished running
        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
            self.batch_progress.reset_on_run()
    def on_run_start(  # type: ignore[override]
        self, data_fetcher: AbstractDataFetcher, dl_max_batches: int, kwargs: OrderedDict
    ) -> None:
        """Adds the passed arguments to the loop's state if necessary.

        Args:
            data_fetcher: the current data_fetcher wrapping the dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            kwargs: the kwargs passed down to the hooks.
        """
        void(kwargs)

        self._dl_max_batches = dl_max_batches
        self._reload_dataloader_state_dict(data_fetcher)
        # creates the iterator inside the fetcher but returns `self`
        self._data_fetcher = iter(data_fetcher)
        # add the previous `fetched` value to properly track `is_last_batch` with no prefetching
        data_fetcher.fetched += self.batch_progress.current.ready
    def advance(  # type: ignore[override]
        self,
        data_fetcher: AbstractDataFetcher,
        dl_max_batches: int,
        kwargs: OrderedDict,
    ) -> None:
        """Calls the evaluation step with the corresponding hooks and updates the logger connector.

        Args:
            data_fetcher: iterator over the dataloader
            dl_max_batches: maximum number of batches the dataloader can produce
            kwargs: the kwargs passed down to the hooks.

        Raises:
            StopIteration: If the current batch is None
        """
        void(dl_max_batches)

        if not isinstance(data_fetcher, DataLoaderIterDataFetcher):
            batch_idx = self.batch_progress.current.ready
            batch = next(data_fetcher)
        else:
            batch_idx, batch = next(data_fetcher)
        self.batch_progress.is_last_batch = data_fetcher.done

        # configure step_kwargs
        kwargs = self._build_kwargs(kwargs, batch, batch_idx)

        self.batch_progress.increment_ready()

        # hook
        self._on_evaluation_batch_start(**kwargs)

        self.batch_progress.increment_started()

        # lightning module methods
        output = self._evaluation_step(**kwargs)
        output = self._evaluation_step_end(output)

        self.batch_progress.increment_processed()

        # track loss history
        self._on_evaluation_batch_end(output, **kwargs)

        self.batch_progress.increment_completed()

        # log batch metrics
        self.trainer._logger_connector.update_eval_step_metrics()

        # track epoch level outputs
        if self._should_track_batch_outputs_for_epoch_end() and output is not None:
            self._outputs.append(output)

        if self.trainer.move_metrics_to_cpu:
            # the evaluation step output is not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        if not self.batch_progress.is_last_batch:
            # if fault tolerant is enabled and process has been notified, exit.
            self.trainer._exit_gracefully_on_signal()
    def on_run_end(self) -> EPOCH_OUTPUT:
        """Returns the outputs of the whole run."""
        outputs, self._outputs = self._outputs, []  # free memory
        self._data_fetcher = None
        return outputs
    def teardown(self) -> None:
        # in case the model changes
        self._should_track_batch_outputs_for_epoch_end.cache_clear()
    def on_save_checkpoint(self) -> Dict:
        state_dict = super().on_save_checkpoint()

        if (
            self.trainer is not None
            and self.trainer.state._fault_tolerant_mode.is_enabled
            and self._data_fetcher is not None
            and not self._num_completed_batches_reached()  # did not finish
            and self.batch_progress.current.ready  # did start
        ):
            state = CombinedLoader._state_dict_fn(self._data_fetcher.dataloader_iter, self._has_completed())
            if state:
                state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(state)

        return state_dict
    def on_load_checkpoint(self, state_dict: Dict) -> None:
        # cache the dataloader state dict until the dataloader objects are available
        # dataset states are collected across all ranks
        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
        if not _fault_tolerant_training() or not dataloader_state_dict:
            return
        self._dataloader_state_dict = dataloader_state_dict[self.trainer.global_rank]
    def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
        if self.trainer.sanity_checking or not self._dataloader_state_dict:
            return
        dataloader = data_fetcher.dataloader
        if isinstance(dataloader, CombinedLoader):
            raise MisconfigurationException(
                "Reloading support hasn't been implemented for `CombinedLoader`. You can request it by opening an"
                " issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
            )
        assert isinstance(dataloader, DataLoader)
        _reload_dataloader_state_dict(dataloader, self._dataloader_state_dict)
        self._dataloader_state_dict = {}

    def _num_completed_batches_reached(self) -> bool:
        epoch_finished_on_completed = self.batch_progress.current.completed == self._dl_max_batches
        dataloader_consumed_successfully = self.batch_progress.is_last_batch and self._has_completed()
        return epoch_finished_on_completed or dataloader_consumed_successfully

    def _has_completed(self) -> bool:
        return self.batch_progress.current.ready == self.batch_progress.current.completed

    def _evaluation_step(self, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """The evaluation step (``validation_step`` or ``test_step`` depending on the trainer's state).

        Args:
            batch: The current batch to run through the step.
            batch_idx: The index of the current batch
            dataloader_idx: the index of the dataloader producing the current batch

        Returns:
            the outputs of the step
        """
        if self.trainer.testing:
            output = self.trainer._call_strategy_hook("test_step", *kwargs.values())
        else:
            output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())

        return output

    def _evaluation_step_end(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]:
        """Calls the ``{validation/test}_step_end`` hook."""
        hook_name = "test_step_end" if self.trainer.testing else "validation_step_end"
        model_output = self.trainer._call_lightning_module_hook(hook_name, *args, **kwargs)
        strategy_output = self.trainer._call_strategy_hook(hook_name, *args, **kwargs)
        output = strategy_output if model_output is None else model_output
        return output

    def _on_evaluation_batch_start(self, **kwargs: Any) -> None:
        """Calls the ``on_{validation/test}_batch_start`` hook.

        Args:
            batch: The current batch to run through the step
            batch_idx: The index of the current batch
            dataloader_idx: The index of the dataloader producing the current batch

        Raises:
            AssertionError: If the number of dataloaders is None (has not yet been set).
        """
        self.trainer._logger_connector.on_batch_start(**kwargs)

        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
        hook_name = "on_test_batch_start" if self.trainer.testing else "on_validation_batch_start"
        self.trainer._call_callback_hooks(hook_name, *kwargs.values())
        self.trainer._call_lightning_module_hook(hook_name, *kwargs.values())

    def _on_evaluation_batch_end(self, output: Optional[STEP_OUTPUT], **kwargs: Any) -> None:
        """The ``on_{validation/test}_batch_end`` hook.

        Args:
            output: The output of the performed step
            batch: The input batch for the step
            batch_idx: The index of the current batch
            dataloader_idx: Index of the dataloader producing the current batch
        """
        kwargs.setdefault("dataloader_idx", 0)  # TODO: the argument should be keyword for these
        hook_name = "on_test_batch_end" if self.trainer.testing else "on_validation_batch_end"
        self.trainer._call_callback_hooks(hook_name, output, *kwargs.values())
        self.trainer._call_lightning_module_hook(hook_name, output, *kwargs.values())

        self.trainer._logger_connector.on_batch_end()

    def _build_kwargs(self, kwargs: OrderedDict, batch: Any, batch_idx: int) -> OrderedDict:
        """Helper function to build the arguments for the current step.

        Args:
            kwargs: The kwargs passed down to the hooks.
            batch: The current batch to run through the step.

        Returns:
            The kwargs passed down to the hooks.
        """
        kwargs.update({"batch": batch, "batch_idx": batch_idx})
        kwargs.move_to_end("batch_idx", last=False)
        kwargs.move_to_end("batch", last=False)
        return kwargs

    @lru_cache(1)
    def _should_track_batch_outputs_for_epoch_end(self) -> bool:
        """Whether the batch outputs should be stored for later usage."""
        model = self.trainer.lightning_module
        if self.trainer.testing:
            return is_overridden("test_epoch_end", model)
        return is_overridden("validation_epoch_end", model)
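Usage sketch (not part of the module above): this loop is never instantiated directly; it is driven by the Trainer when ``trainer.validate`` or ``trainer.test`` runs. The model, layer sizes, and dataset below are illustrative only. Overriding ``validation_epoch_end`` is what makes ``_should_track_batch_outputs_for_epoch_end`` return ``True``, so the per-batch step outputs collected in ``self._outputs`` are forwarded to that hook at the end of the run.

# Minimal, hypothetical example exercising the evaluation loop via Trainer.validate.
import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl


class DemoModel(pl.LightningModule):  # illustrative model, not part of the library
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 1)

    def validation_step(self, batch, batch_idx):
        # each batch from the TensorDataset is a one-element tuple
        (x,) = batch
        loss = self.layer(x).sum()
        self.log("val_loss", loss)
        return {"loss": loss}

    def validation_epoch_end(self, outputs):
        # `outputs` is the list of per-batch results accumulated by the evaluation epoch loop
        avg = torch.stack([o["loss"] for o in outputs]).mean()
        self.log("val_loss_epoch", avg)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(8, 4)), batch_size=2)
    trainer = pl.Trainer(limit_val_batches=2, logger=False, enable_checkpointing=False)
    # `validate` runs the trainer's evaluation loops, which include EvaluationEpochLoop
    trainer.validate(DemoModel(), dataloaders=data)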