Shortcuts

Source code for pytorch_lightning.loops.dataloader.prediction_loop

from typing import Any, List, Optional, Sequence

from deprecate.utils import void
from torch.utils.data import DataLoader

from pytorch_lightning.loops.dataloader.dataloader_loop import DataLoaderLoop
from pytorch_lightning.loops.epoch.prediction_epoch_loop import PredictionEpochLoop
from pytorch_lightning.plugins import DDPSpawnPlugin
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import _PREDICT_OUTPUT


class PredictionLoop(DataLoaderLoop):
    """Loop to run over dataloaders for prediction.

    Each call to :meth:`advance` runs the connected :class:`PredictionEpochLoop` over one
    prediction dataloader and appends that dataloader's predictions and batch indices to
    :attr:`predictions` and :attr:`epoch_batch_indices`.
    """

    def __init__(self) -> None:
        super().__init__()
        # One entry per processed dataloader; (re)initialised to empty lists in `reset`.
        self.predictions: Optional[List[List[Any]]] = None
        self.epoch_batch_indices: Optional[List[List[int]]] = None
        self.epoch_loop = PredictionEpochLoop()
        self._results = None  # for `trainer._results` access
        self._return_predictions: bool = False

    @property
    def return_predictions(self) -> bool:
        """Whether to return the predictions or not."""
        return self._return_predictions

    @return_predictions.setter
    def return_predictions(self, return_predictions: Optional[bool] = None) -> None:
        """Set whether predictions are returned from the loop.

        Args:
            return_predictions: ``True``/``False`` to force the behavior, or ``None`` to use the
                default, which is ``True`` unless the training type plugin is a ``DDPSpawnPlugin``.

        Raises:
            MisconfigurationException: If ``return_predictions`` is truthy while the training type
                plugin is a ``DDPSpawnPlugin`` (or a subclass of it).
        """
        # `DDPSpawnPlugin` plugins and derivatives don't support return predictions.
        is_ddp_spawn = isinstance(self.trainer.training_type_plugin, DDPSpawnPlugin)
        if return_predictions and is_ddp_spawn:
            raise MisconfigurationException(
                "`return_predictions` should be set to `False` when using the `DDPSpawnPlugin` or children class. "
                f"Found {return_predictions} with training_type_plugin {type(self.trainer.training_type_plugin)}."
            )
        # For non `DDPSpawnPlugin` plugin, the `return_predictions` is True by default unless user decide otherwise.
        self._return_predictions = not is_ddp_spawn if return_predictions is None else return_predictions

    @property
    def num_dataloaders(self) -> int:
        """Returns the number of prediction dataloaders.

        If the first element of :attr:`dataloaders` is itself a list/tuple, the length of that
        inner sequence is used instead of the outer length.
        """
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        length = len(dataloaders)
        if len(dataloaders) > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def max_batches(self) -> List[int]:
        """The max number of batches this loop will run for each dataloader.

        A single ``int`` from ``trainer.num_predict_batches`` is broadcast to one entry per
        dataloader.
        """
        max_batches = self.trainer.num_predict_batches
        if isinstance(max_batches, int):
            max_batches = [max_batches] * len(self.dataloaders)
        return max_batches

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns all prediction dataloaders."""
        return self.trainer.predict_dataloaders

    @property
    def skip(self) -> bool:
        """Whether the loop should be skipped: ``True`` when the per-dataloader batch limits sum to zero."""
        return sum(self.max_batches) == 0

    def connect(self, epoch_loop: PredictionEpochLoop) -> None:
        """Connect the prediction epoch loop with this loop."""
        self.epoch_loop = epoch_loop

    def reset(self) -> None:
        """Resets the internal state of the loop for a new run."""
        super().reset()
        self.predictions = []
        self.epoch_batch_indices = []

    def on_run_start(self) -> None:
        """Calls ``_on_predict_start`` hook."""
        self._on_predict_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Predicts one entire dataloader.

        The dataloader is first passed through the training type plugin's
        ``process_dataloader``. Its predictions and batch indices are then appended to the
        loop's accumulators.
        """
        # Arguments are accepted for the loop interface but not used here.
        void(*args, **kwargs)
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
        dataloader_iter = enumerate(dataloader)
        dl_max_batches = self.max_batches[self.current_dataloader_idx]

        dl_predictions, dl_batch_indices = self.epoch_loop.run(
            dataloader_iter, self.current_dataloader_idx, dl_max_batches, self.num_dataloaders, self.return_predictions
        )
        self.predictions.append(dl_predictions)
        self.epoch_batch_indices.append(dl_batch_indices)

    def on_run_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` and ``on_predict_end`` hooks and returns results from all dataloaders.

        Returns:
            The collected predictions, or ``None`` when :attr:`return_predictions` is ``False``.
        """
        results = self._on_predict_epoch_end()
        self._on_predict_end()
        return results

    def _on_predict_start(self) -> None:
        """Sets the model to eval mode, zeroes its gradients and calls the start hooks.

        Calls the ``on_predict_start`` and ``on_predict_epoch_start`` hooks.
        """
        # eval mode + clear any existing gradients.
        # NOTE(review): `zero_grad` only clears gradients; it does not disable autograd.
        # Presumably the caller wraps prediction in a no-grad context -- confirm.
        self._on_predict_model_eval()
        self.trainer.lightning_module.zero_grad()

        # hook
        self.trainer.call_hook("on_predict_start")
        self.trainer.call_hook("on_predict_epoch_start")

    def _on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
        """Calls ``on_predict_epoch_end`` hook.

        Returns:
            ``None`` when :attr:`return_predictions` is ``False``. Otherwise the predictions for
            all dataloaders, unwrapped to the single dataloader's predictions when there is
            exactly one dataloader.
        """
        results = self.predictions

        self.trainer.call_hook("on_predict_epoch_end", results)

        if self.return_predictions:
            return results[0] if self.num_dataloaders == 1 else results
        return None

    def _on_predict_end(self) -> None:
        """Clears the stored predictions and batch indices, then calls the ``on_predict_end`` hook."""
        # clear memory. the predictions are extracted in `on_predict_epoch_end`.
        self.predictions = []
        self.epoch_batch_indices = []

        # hook
        self.trainer.call_hook("on_predict_end")

    def _on_predict_model_eval(self) -> None:
        """Calls ``on_predict_model_eval`` hook."""
        model_ref = self.trainer.lightning_module
        model_ref.on_predict_model_eval()

© Copyright 2018-2023, William Falcon et al.

Built with Sphinx using a theme provided by Read the Docs.