Source code for pytorch_lightning.loops.dataloader.evaluation_loop
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.importosimportshutilimportsysfromcollectionsimportChainMap,OrderedDictfromfunctoolsimportpartialfromtypingimportAny,Iterable,List,Optional,Sequence,Tuple,Type,Unionfromdeprecate.utilsimportvoidfromtorchimportTensorfromtorch.utils.data.dataloaderimportDataLoaderimportpytorch_lightningasplfrompytorch_lightning.acceleratorsimportCUDAAcceleratorfrompytorch_lightning.callbacks.progress.rich_progressimport_RICH_AVAILABLEfrompytorch_lightning.loops.dataloaderimportDataLoaderLoopfrompytorch_lightning.loops.epochimportEvaluationEpochLoopfrompytorch_lightning.loops.utilitiesimport_set_sampler_epochfrompytorch_lightning.trainer.connectors.logger_connector.resultimport_OUT_DICT,_ResultCollectionfrompytorch_lightning.trainer.statesimportTrainerFnfrompytorch_lightning.utilities.apply_funcimportapply_to_collectionfrompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.fetchingimport(AbstractDataFetcher,DataFetcher,DataLoaderIterDataFetcher,InterBatchParallelDataFetcher,)frompytorch_lightning.utilities.rank_zeroimportrank_zero_warnfrompytorch_lightning.utilities.signature_utilsimportis_param_in_hook_signaturefrompytorch_lightning.utilities.typesimportEPOCH_OUTPUTif_RICH_AVAILABLE:fromrichimportget_consolefromrich.tableimportColumn,Table
class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation (validation or test)."""

    def __init__(self, verbose: bool = True) -> None:
        super().__init__()
        self.epoch_loop = EvaluationEpochLoop()
        self.verbose = verbose

        self._results = _ResultCollection(training=False)
        self._outputs: List[EPOCH_OUTPUT] = []
        self._logged_outputs: List[_OUT_DICT] = []
        self._max_batches: List[int] = []
        self._has_run: bool = False
        self._data_fetcher: Optional[AbstractDataFetcher] = None

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders."""
        # handles the case where the user returns several dataloaders at once:
        # return dl1, dl2
        loaders = self.dataloaders
        count = len(loaders)
        if count > 0 and isinstance(loaders[0], (list, tuple)):
            count = len(loaders[0])
        return count

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the validation or test dataloaders, or an empty sequence if none are set."""
        if self.trainer.testing:
            loaders = self.trainer.test_dataloaders
        else:
            loaders = self.trainer.val_dataloaders
        return [] if loaders is None else loaders

    @property
    def prefetch_batches(self) -> int:
        """Number of batches to prefetch: 1 for unsized dataloaders or inter-batch parallelism, else 0."""
        if self.trainer.testing:
            batches = self.trainer.num_test_batches
        else:
            batches = self.trainer.num_val_batches
        is_unsized = batches[self.current_dataloader_idx] == float("inf")
        inter_batch_parallelism = os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1"
        return 1 if is_unsized or inter_batch_parallelism else 0

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:  # type: ignore[override]
        """Connect the evaluation epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return super().done or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped (no batches to run across all dataloaders)."""
        return sum(self._get_max_batches()) == 0

    def reset(self) -> None:
        """Resets the internal state of the loop."""
        self._max_batches = self._get_max_batches()
        # bookkeeping
        self._outputs = []
        self._logged_outputs = []

        # a scalar limit applies uniformly to every dataloader
        if isinstance(self._max_batches, int):
            self._max_batches = [self._max_batches] * len(self.dataloaders)

        super().reset()
        # when restarting, if we are running `validate` or `test` twice, since there's no concept of `max_epochs` we
        # need to reset the current state when the loop has finished running
        if self.done and self.trainer.state.fn != TrainerFn.FITTING:
            self.dataloader_progress.reset_on_run()

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks."""
        void(*args, **kwargs)

        fetcher_cls = _select_data_fetcher_type(self.trainer)
        self._data_fetcher = fetcher_cls(prefetch_batches=self.prefetch_batches)

        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self._on_evaluation_start()
        self._on_evaluation_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader."""
        void(*args, **kwargs)

        dataloader_idx = self.current_dataloader_idx
        dataloader = self.current_dataloader
        assert self._data_fetcher is not None
        self._data_fetcher.setup(
            dataloader,
            batch_to_device=partial(
                self.trainer._call_strategy_hook, "batch_to_device", dataloader_idx=dataloader_idx
            ),
        )
        dl_max_batches = self._max_batches[dataloader_idx]

        # only pass the dataloader index to the step hooks when there are multiple dataloaders
        step_kwargs = OrderedDict()
        if self.num_dataloaders > 1:
            step_kwargs["dataloader_idx"] = dataloader_idx
        dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, step_kwargs)

        # store batch level output per dataloader
        self._outputs.append(dl_outputs)

        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True

    def on_run_end(self) -> List[_OUT_DICT]:
        """Runs the ``_on_evaluation_epoch_end`` hook."""
        # if `done` returned True before any iterations were done, this won't have been called in `on_advance_end`
        self.trainer._logger_connector.epoch_end_reached()

        # hook
        self._evaluation_epoch_end(self._outputs)
        self._outputs = []  # free memory

        # hook
        self._on_evaluation_epoch_end()

        logged_outputs, self._logged_outputs = self._logged_outputs, []  # free memory
        # include any logged outputs on epoch_end
        epoch_end_logged_outputs = self.trainer._logger_connector.update_eval_epoch_metrics()
        all_logged_outputs = dict(ChainMap(*logged_outputs))  # list[dict] -> dict
        all_logged_outputs.update(epoch_end_logged_outputs)
        for dl_outputs in logged_outputs:
            dl_outputs.update(epoch_end_logged_outputs)

        # log metrics
        self.trainer._logger_connector.log_eval_end_metrics(all_logged_outputs)

        # hook
        self._on_evaluation_end()

        # enable train mode again
        self._on_evaluation_model_train()

        if self.verbose and self.trainer.is_global_zero:
            assert self.trainer.state.stage is not None
            self._print_results(logged_outputs, self.trainer.state.stage)

        return logged_outputs

    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader."""
        if self.trainer.testing:
            return self.trainer.num_test_batches
        if self.trainer.sanity_checking:
            return self.trainer.num_sanity_val_batches
        return self.trainer.num_val_batches

    def _reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary."""
        dataloaders = None
        if self.trainer.testing:
            self.trainer.reset_test_dataloader()
            dataloaders = self.trainer.test_dataloaders
        elif self.trainer.val_dataloaders is None or self.trainer._data_connector._should_reload_val_dl:
            self.trainer.reset_val_dataloader()
            dataloaders = self.trainer.val_dataloaders
        if dataloaders is not None:
            self.epoch_loop._reset_dl_batch_idx(len(dataloaders))

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks."""
        assert self._results is not None
        self._results.to(device=self.trainer.lightning_module.device)

        hook = "on_test_start" if self.trainer.testing else "on_validation_start"
        self.trainer._call_callback_hooks(hook, *args, **kwargs)
        self.trainer._call_lightning_module_hook(hook, *args, **kwargs)
        self.trainer._call_strategy_hook(hook, *args, **kwargs)

    def _on_evaluation_model_eval(self) -> None:
        """Sets model to eval mode."""
        hook = "on_test_model_eval" if self.trainer.testing else "on_validation_model_eval"
        self.trainer._call_lightning_module_hook(hook)

    def _on_evaluation_model_train(self) -> None:
        """Sets model to train mode."""
        hook = "on_test_model_train" if self.trainer.testing else "on_validation_model_train"
        self.trainer._call_lightning_module_hook(hook)

    def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook."""
        hook = "on_test_end" if self.trainer.testing else "on_validation_end"
        self.trainer._call_callback_hooks(hook, *args, **kwargs)
        self.trainer._call_lightning_module_hook(hook, *args, **kwargs)
        self.trainer._call_strategy_hook(hook, *args, **kwargs)

        # reset the logger connector state
        self.trainer._logger_connector.reset_results()

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks."""
        self.trainer._logger_connector.on_epoch_start()
        self.trainer._call_callback_hooks("on_epoch_start", *args, **kwargs)
        self.trainer._call_lightning_module_hook("on_epoch_start", *args, **kwargs)

        hook = "on_test_epoch_start" if self.trainer.testing else "on_validation_epoch_start"
        self.trainer._call_callback_hooks(hook, *args, **kwargs)
        self.trainer._call_lightning_module_hook(hook, *args, **kwargs)

    def _evaluation_epoch_end(self, outputs: List[EPOCH_OUTPUT]) -> None:
        """Runs ``{validation/test}_epoch_end``."""
        self.trainer._logger_connector._evaluation_epoch_end()

        # with a single dataloader don't pass a 2D list
        output_or_outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]] = (
            outputs[0] if len(outputs) > 0 and self.num_dataloaders == 1 else outputs
        )

        # call the model epoch end
        hook = "test_epoch_end" if self.trainer.testing else "validation_epoch_end"
        self.trainer._call_lightning_module_hook(hook, output_or_outputs)

    def _on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` hook."""
        hook = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer._call_callback_hooks(hook)
        self.trainer._call_lightning_module_hook(hook)

        self.trainer._call_callback_hooks("on_epoch_end")
        self.trainer._call_lightning_module_hook("on_epoch_end")
        self.trainer._logger_connector.on_epoch_end()

    @staticmethod
    def _get_keys(data: dict) -> Iterable[Tuple[str, ...]]:
        # yields the key path to every leaf value as a tuple of keys
        for key, value in data.items():
            if isinstance(value, dict):
                for sub_key in apply_to_collection(value, dict, EvaluationLoop._get_keys):
                    # needs to be in parentheses for older Python versions
                    yield (key, *sub_key)
            else:
                yield key,

    @staticmethod
    def _find_value(data: dict, target: Iterable[str]) -> Optional[Any]:
        # walks `data` along the key path `target`; None if any key is missing
        head, *tail = target
        if head not in data:
            return None
        value = data[head]
        return value if not tail else EvaluationLoop._find_value(value, tail)

    @staticmethod
    def _print_results(results: List[_OUT_DICT], stage: str) -> None:
        # remove the dl idx suffix
        results = [{k.split("/dataloader_idx_")[0]: v for k, v in result.items()} for result in results]
        metrics_paths = {k for keys in apply_to_collection(results, dict, EvaluationLoop._get_keys) for k in keys}
        if not metrics_paths:
            return

        metrics_strs = [":".join(metric) for metric in metrics_paths]
        # sort both lists based on metrics_strs
        metrics_strs, metrics_paths = zip(*sorted(zip(metrics_strs, metrics_paths)))

        headers = [f"DataLoader {i}" for i in range(len(results))]

        # fallback is useful for testing of printed output
        term_size = shutil.get_terminal_size(fallback=(120, 30)).columns or 120
        max_length = int(min(max(len(max(metrics_strs, key=len)), len(max(headers, key=len)), 25), term_size / 2))

        rows: List[List[Any]] = [[] for _ in metrics_paths]

        for result in results:
            for metric, row in zip(metrics_paths, rows):
                val = EvaluationLoop._find_value(result, metric)
                if val is not None:
                    if isinstance(val, Tensor):
                        # scalar tensors print as plain numbers, others as lists
                        val = val.item() if val.numel() == 1 else val.tolist()
                    row.append(f"{val}")
                else:
                    row.append(" ")

        # keep one column with max length for metrics
        num_cols = int((term_size - max_length) / max_length)

        # chunk the dataloader columns into groups that fit the terminal width
        for i in range(0, len(headers), num_cols):
            table_headers = headers[i : (i + num_cols)]
            table_rows = [row[i : (i + num_cols)] for row in rows]

            table_headers.insert(0, f"{stage} Metric".capitalize())

            if _RICH_AVAILABLE:
                columns = [Column(h, justify="center", style="magenta", width=max_length) for h in table_headers]
                columns[0].style = "cyan"

                table = Table(*columns)
                for metric, row in zip(metrics_strs, table_rows):
                    row.insert(0, metric)
                    table.add_row(*row)

                console = get_console()
                console.print(table)
            else:
                row_format = f"{{:^{max_length}}}" * len(table_headers)
                half_term_size = int(term_size / 2)

                try:
                    # some terminals do not support this character
                    if sys.stdout.encoding is not None:
                        "─".encode(sys.stdout.encoding)
                except UnicodeEncodeError:
                    bar_character = "-"
                else:
                    bar_character = "─"
                bar = bar_character * term_size

                lines = [bar, row_format.format(*table_headers).rstrip(), bar]
                for metric, row in zip(metrics_strs, table_rows):
                    # deal with column overflow
                    if len(metric) > half_term_size:
                        while len(metric) > half_term_size:
                            row_metric = metric[:half_term_size]
                            metric = metric[half_term_size:]
                            lines.append(row_format.format(row_metric, *row).rstrip())
                        lines.append(row_format.format(metric, " ").rstrip())
                    else:
                        lines.append(row_format.format(metric, *row).rstrip())
                lines.append(bar)
                print(os.linesep.join(lines))
def _select_data_fetcher_type(trainer: "pl.Trainer") -> Type[AbstractDataFetcher]:
    """Pick the data-fetcher class for evaluation based on the step signature and environment.

    Returns :class:`DataLoaderIterDataFetcher` when the step hook explicitly takes a
    ``dataloader_iter`` argument, :class:`InterBatchParallelDataFetcher` when
    ``PL_INTER_BATCH_PARALLELISM=1`` (CUDA only), and :class:`DataFetcher` otherwise.
    """
    module = trainer.lightning_module
    step_fx_name = "test_step" if trainer.testing else "validation_step"
    step_hook = getattr(module, step_fx_name)

    if is_param_in_hook_signature(step_hook, "dataloader_iter", explicit=True):
        rank_zero_warn(
            f"Found `dataloader_iter` argument in the `{step_fx_name}`. Note that the support for "
            "this signature is experimental and the behavior is subject to change."
        )
        return DataLoaderIterDataFetcher

    if os.getenv("PL_INTER_BATCH_PARALLELISM", "0") == "1":
        # inter-batch parallelism relies on CUDA streams
        if not isinstance(trainer.accelerator, CUDAAccelerator):
            raise MisconfigurationException("Inter batch parallelism is available only when using Nvidia GPUs.")
        return InterBatchParallelDataFetcher

    return DataFetcher
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.