class PredictionEpochLoop(Loop):
    """Loop performing prediction on arbitrary sequentially used dataloaders."""

    def __init__(self) -> None:
        super().__init__()
        # Whether ``run`` should hand the collected predictions back to the caller.
        self.return_predictions = False
        # Predictions accumulated over the current dataloader.
        self.predictions: List[Any] = []
        # Indices of the samples in the batch currently being processed.
        self.current_batch_indices: List[int] = []
        self.batch_progress = Progress()

        self._dl_max_batches = 0
        self._num_dataloaders = 0
        self._warning_cache = WarningCache()
        self._seen_batch_indices: List[List[int]] = []

    @property
    def done(self) -> bool:
        """Ends prediction when the iteration count exceeds the total number of available batches."""
        return self.batch_progress.current.completed >= self._dl_max_batches

    @property
    def should_store_predictions(self) -> bool:
        """Whether the predictions should be stored for later usage (e.g. aggregation or returning)."""
        has_epoch_writer = any(cb.interval.on_epoch for cb in self.trainer.prediction_writer_callbacks)
        return self.return_predictions or has_epoch_writer
[docs]defconnect(self,**kwargs:"Loop")->None:raiseNotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
[docs]defreset(self)->None:"""Resets the loops internal state."""self._seen_batch_indices=[]self.predictions=[]self.batch_progress.reset_on_run()
[docs]defon_run_start(# type: ignore[override]self,dataloader_iter:Iterator,dataloader_idx:int,dl_max_batches:int,num_dataloaders:int,)->None:"""Prepares the loops internal state. Args: dataloader_iter: the iterator over the current dataloader dataloader_idx: the index of the current dataloader dl_max_batches: the maximum number of batches the current loader can produce num_dataloaders: the total number of dataloaders """void(dataloader_iter,dataloader_idx)self._dl_max_batches=dl_max_batchesself._num_dataloaders=num_dataloaders# this call requires that `self.return_predictions` is setself._seen_batch_indices=self._get_batch_indices(dataloader_idx)ifself.should_store_predictionselse[]
[docs]defadvance(# type: ignore[override]self,dataloader_iter:Iterator,dataloader_idx:int,dl_max_batches:int,num_dataloaders:int,)->None:"""Runs one prediction step. Args: dataloader_iter: the iterator over the current dataloader dataloader_idx: the index of the current dataloader dl_max_batches: the maximum number of batches the current loader can produce num_dataloaders: the total number of dataloaders """action_name=f"[{self.__class__.__name__}].predict_dataloader_idx_{dataloader_idx}_next"withself.trainer.profiler.profile(action_name):batch_idx,batch=next(dataloader_iter)self._seen_batch_indices=self._get_batch_indices(dataloader_idx)ifself.should_store_predictionselse[]# we need to truncate the list of batch indices due to prefetching in the dataloader and Lightningself._seen_batch_indices=self._seen_batch_indices[:(self.batch_progress.current.completed+1)]ifbatchisNone:raiseStopIterationbatch=self.trainer._call_strategy_hook("batch_to_device",batch,dataloader_idx=dataloader_idx)self.batch_progress.increment_ready()self._predict_step(batch,batch_idx,dataloader_idx)
[docs]defon_run_end(self)->Tuple[List[Any],List[List[int]]]:"""Returns the predictions and the corresponding batch indices."""predictions,all_batch_indices=self.predictions,self._seen_batch_indicesself.predictions,self._seen_batch_indices=[],[]# free memoryreturnpredictions,all_batch_indices
def_predict_step(self,batch:Any,batch_idx:int,dataloader_idx:int)->None:"""Runs the actual predict step together with all the necessary bookkeeping and the hooks tied to the predict step. Args: batch: the current batch to run the prediction on batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch """# configure step_kwargsstep_kwargs=self._build_kwargs(batch,batch_idx,dataloader_idx)# extract batch_indices and store thembatch_indices=self._get_batch_indices(dataloader_idx)self.current_batch_indices=batch_indices[batch_idx]ifbatch_indiceselse[]self.trainer._call_callback_hooks("on_predict_batch_start",batch,batch_idx,dataloader_idx)self.trainer._call_lightning_module_hook("on_predict_batch_start",batch,batch_idx,dataloader_idx)self.batch_progress.increment_started()predictions=self.trainer._call_strategy_hook("predict_step",*step_kwargs.values())self.batch_progress.increment_processed()ifpredictionsisNone:self._warning_cache.warn("predict returned None if it was on purpose, ignore this warning...")self.trainer._call_callback_hooks("on_predict_batch_end",predictions,batch,batch_idx,dataloader_idx)self.trainer._call_lightning_module_hook("on_predict_batch_end",predictions,batch,batch_idx,dataloader_idx)self.batch_progress.increment_completed()ifself.should_store_predictions:self.predictions.append(move_data_to_device(predictions,torch.device("cpu")))def_build_kwargs(self,batch:Any,batch_idx:int,dataloader_idx:int)->Dict[str,Any]:"""Assembles the keyword arguments for the ``predict_step`` Args: batch: the current batch to run the prediction on batch_idx: the index of the current batch dataloader_idx: the index of the dataloader producing the current batch Returns: the dictionary containing all the keyboard arguments for the predict step 
"""step_kwargs=OrderedDict([("batch",batch),("batch_idx",batch_idx)])ifself._num_dataloaders>1:step_kwargs["dataloader_idx"]=dataloader_idxreturnstep_kwargsdef_get_batch_indices(self,dataloader_idx:int)->List[List[int]]:"""Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`."""# the batch_sampler is not be defined in case of CombinedDataLoadersassertself.trainer.predict_dataloadersbatch_sampler=getattr(self.trainer.predict_dataloaders[dataloader_idx],"batch_sampler",None,)ifisinstance(batch_sampler,IndexBatchSamplerWrapper):returnbatch_sampler.seen_batch_indiceswarning_cache.warn("Lightning couldn't infer the indices fetched for your dataloader.")return[]
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.