class PredictionLoop(DataLoaderLoop):
    """Loop to run over dataloaders for prediction.

    Each call to :meth:`advance` runs :attr:`epoch_loop` over one prediction dataloader and appends that
    dataloader's predictions and batch indices to :attr:`predictions` and :attr:`epoch_batch_indices`.
    """

    def __init__(self) -> None:
        super().__init__()
        # One inner list per dataloader, appended in ``advance``.
        self.predictions: List[List[Any]] = []
        self.epoch_batch_indices: List[List[int]] = []
        self.epoch_loop = PredictionEpochLoop()

        self._results = None  # for `trainer._results` access
        self._return_predictions: bool = False

    @property
    def return_predictions(self) -> bool:
        """Whether to return the predictions or not."""
        return self._return_predictions

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        """Validate and store ``return_predictions``, and propagate it to :attr:`epoch_loop`.

        Args:
            return_predictions: an explicit value, or ``None`` to use the default. The default is ``True``
                unless the trainer's strategy is a ``DDPSpawnStrategy`` (or subclass), in which case it is ``False``.

        Raises:
            MisconfigurationException: If ``return_predictions`` is truthy while the trainer's strategy is a
                ``DDPSpawnStrategy`` (or subclass).
        """
        # `DDPSpawnStrategy` plugins and derivatives don't support return predictions.
        is_ddp_spawn = isinstance(self.trainer.strategy, DDPSpawnStrategy)
        if return_predictions and is_ddp_spawn:
            raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the `DDPSpawnStrategy` or children class. "
                f"Found {return_predictions} with strategy {type(self.trainer.strategy)}."
            )
        # For strategies other than `DDPSpawnStrategy`, `return_predictions` defaults to True unless the user
        # decides otherwise.
        self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions
        self.epoch_loop.return_predictions = self._return_predictions

    @property
    def num_dataloaders(self) -> int:
        """Returns the number of prediction dataloaders.

        If the first entry of :attr:`dataloaders` is itself a list or tuple, the length of that entry is used.
        """
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            return len(dataloaders[0])
        return len(dataloaders)

    @property
    def max_batches(self) -> List[int]:
        """The max number of batches this loop will run for each dataloader."""
        return self.trainer.num_predict_batches

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns all prediction dataloaders, or an empty list if the trainer has none."""
        dataloaders = self.trainer.predict_dataloaders
        return [] if dataloaders is None else dataloaders

    @property
    def skip(self) -> bool:
        """Whether to skip the loop, which is the case when the per-dataloader max batch counts sum to zero."""
        return sum(self.max_batches) == 0

    def connect(self, epoch_loop: PredictionEpochLoop) -> None:  # type: ignore[override]
        """Connect the prediction epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        self.predictions = []
        self.epoch_batch_indices = []

        super().reset()
        # when restarting, if we are running twice, since there's no concept of `max_epochs` we need to reset the
        # current state when the loop has finished running
        if self.done:
            self.dataloader_progress.reset_on_run()

    def on_run_start(self) -> None:  # type: ignore[override]
        """Prepares the module and runs the start-of-prediction hooks.

        Calls the ``on_predict_model_eval`` LightningModule hook, clears the module's gradients with
        ``zero_grad()``, and then runs the ``on_predict_start`` and ``on_predict_epoch_start`` hooks.
        """
        self.trainer._call_lightning_module_hook("on_predict_model_eval")
        self.trainer.lightning_module.zero_grad()
        self._on_predict_start()
        self._on_predict_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Predicts one entire dataloader.

        Runs :attr:`epoch_loop` over the current dataloader and appends the resulting predictions and batch
        indices to :attr:`predictions` and :attr:`epoch_batch_indices`.
        """
        void(*args, **kwargs)
        dataloader = self.current_dataloader
        if dataloader is not None:
            # NOTE(review): presumably forwards the fit loop's processed-epoch count to the dataloader's
            # sampler(s); confirm against `_set_sampler_epoch`.
            _set_sampler_epoch(dataloader, self.trainer.fit_loop.epoch_progress.current.processed)
        # Let the strategy transform the dataloader before it is iterated.
        dataloader = self.trainer.strategy.process_dataloader(dataloader)
        dataloader_iter = enumerate(dataloader)
        dl_max_batches = self.max_batches[self.current_dataloader_idx]

        dl_predictions, dl_batch_indices = self.epoch_loop.run(
            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders
        )
        self.predictions.append(dl_predictions)
        self.epoch_batch_indices.append(dl_batch_indices)

    def on_run_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders."""
        results = self._on_predict_epoch_end()
        self._on_predict_end()
        return results

    def _on_predict_start(self) -> None:
        """Calls ``on_predict_start`` hooks."""
        self.trainer._call_callback_hooks("on_predict_start")
        self.trainer._call_lightning_module_hook("on_predict_start")
        self.trainer._call_strategy_hook("on_predict_start")

    def _on_predict_epoch_start(self) -> None:
        """Calls ``on_predict_epoch_start`` hooks."""
        self.trainer._call_callback_hooks("on_predict_epoch_start")
        self.trainer._call_lightning_module_hook("on_predict_epoch_start")

    def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` hook.

        Returns:
            ``None`` when :attr:`return_predictions` is falsy. Otherwise, with exactly one dataloader, that
            dataloader's predictions; with any other count, the list of per-dataloader predictions.
        """
        results = self.predictions

        self.trainer._call_callback_hooks("on_predict_epoch_end", results)
        self.trainer._call_lightning_module_hook("on_predict_epoch_end", results)

        if not self.return_predictions:
            return None
        # Unwrap the outer per-dataloader list in the single-dataloader case.
        return results[0] if self.num_dataloaders == 1 else results

    def _on_predict_end(self) -> None:
        """Clears the stored predictions and batch indices, then calls the ``on_predict_end`` hooks."""
        # clear memory. the predictions are extracted in `on_predict_epoch_end`.
        # Rebinding to new lists (instead of calling `.clear()`) keeps the list object already returned by
        # `_on_predict_epoch_end` intact.
        self.predictions = []
        self.epoch_batch_indices = []

        # hook
        self.trainer._call_callback_hooks("on_predict_end")
        self.trainer._call_lightning_module_hook("on_predict_end")
        self.trainer._call_strategy_hook("on_predict_end")
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.