# Source code for pytorch_lightning.loops.dataloader.evaluation_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Sequence
from deprecate.utils import void
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loops.dataloader import DataLoaderLoop
from pytorch_lightning.loops.epoch import EvaluationEpochLoop
from pytorch_lightning.trainer.connectors.logger_connector.result import _OUT_DICT, ResultCollection
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
class EvaluationLoop(DataLoaderLoop):
    """Loops over all dataloaders for evaluation.

    For every dataloader the connected ``epoch_loop`` is run once; the per-dataloader outputs are
    collected in ``self.outputs`` and handed to the ``{validation/test}_epoch_end`` hook when the run ends.
    Whether the validation or the test flavour of each hook is used depends on ``self.trainer.testing``.
    """

    def __init__(self) -> None:
        super().__init__()
        # batch-level outputs, one entry appended per dataloader in ``advance``
        self.outputs: List[EPOCH_OUTPUT] = []
        self.epoch_loop = EvaluationEpochLoop()
        self._results = ResultCollection(training=False)
        # NOTE(review): ``_outputs`` is not read anywhere in this class; kept for backward compatibility
        self._outputs: List[EPOCH_OUTPUT] = []
        # per-dataloader batch limits, refreshed in ``reset``
        self._max_batches: List[int] = []
        # set to ``True`` by ``advance`` once a non-sanity-check evaluation has run
        self._has_run: bool = False

    @property
    def num_dataloaders(self) -> int:
        """Returns the total number of dataloaders.

        Returns:
            ``0`` when ``self.dataloaders`` is ``None``; otherwise the length of the dataloader sequence or,
            if its first element is itself a list/tuple, the length of that first element.
        """
        # case where user does:
        # return dl1, dl2
        dataloaders = self.dataloaders
        if dataloaders is None:
            return 0
        length = len(dataloaders)
        if length > 0 and isinstance(dataloaders[0], (list, tuple)):
            length = len(dataloaders[0])
        return length

    @property
    def dataloaders(self) -> Sequence[DataLoader]:
        """Returns the test dataloaders when ``trainer.testing`` is set, else the validation dataloaders."""
        if self.trainer.testing:
            return self.trainer.test_dataloaders
        return self.trainer.val_dataloaders

    def connect(self, epoch_loop: EvaluationEpochLoop) -> None:
        """Connect the evaluation epoch loop with this loop.

        Args:
            epoch_loop: the loop that ``advance`` will run for each dataloader; replaces the current one.
        """
        self.epoch_loop = epoch_loop

    @property
    def done(self) -> bool:
        """Returns whether all dataloaders are processed or evaluation should be skipped altogether."""
        return super().done or self.skip

    @property
    def skip(self) -> bool:
        """Returns whether the evaluation should be skipped.

        Evaluation is skipped when the batch limits of all dataloaders add up to zero.
        """
        max_batches = self._get_max_batches()
        # ``reset`` tolerates a bare ``int`` limit, so accept one here too instead of failing in ``sum``
        if isinstance(max_batches, int):
            return max_batches == 0
        return sum(max_batches) == 0

    def reset(self) -> None:
        """Resets the internal state of the loop.

        Refreshes the per-dataloader batch limits, clears the collected outputs and then delegates to the
        parent ``reset``.
        """
        self._max_batches = self._get_max_batches()
        # bookkeeping
        self.outputs = []
        if isinstance(self._max_batches, int):
            # a single limit applies to every dataloader
            self._max_batches = [self._max_batches] * len(self.dataloaders)
        super().reset()

    def on_run_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs the ``_on_evaluation_model_eval``, ``_on_evaluation_start`` and ``_on_evaluation_epoch_start``
        hooks.

        Positional and keyword arguments are accepted for signature compatibility and discarded.
        """
        void(*args, **kwargs)
        # hook
        self._on_evaluation_model_eval()
        self.trainer.lightning_module.zero_grad()
        self._on_evaluation_start()
        self._on_evaluation_epoch_start()

    def advance(self, *args: Any, **kwargs: Any) -> None:
        """Performs evaluation on one single dataloader.

        The current dataloader is processed by the training type plugin, wrapped by the data connector, and
        then run through ``epoch_loop`` with the batch limit stored for its index. Positional and keyword
        arguments are accepted for signature compatibility and discarded.
        """
        void(*args, **kwargs)
        dataloader_idx: int = self.current_dataloader_idx
        dataloader = self.trainer.training_type_plugin.process_dataloader(self.current_dataloader)
        # the wrapped dataloader is also kept on the loop as ``data_fetcher``
        self.data_fetcher = dataloader = self.trainer._data_connector.get_profiled_dataloader(
            dataloader, dataloader_idx=dataloader_idx
        )
        dl_max_batches = self._max_batches[dataloader_idx]
        dl_outputs = self.epoch_loop.run(dataloader, dataloader_idx, dl_max_batches, self.num_dataloaders)
        # store batch level output per dataloader
        self.outputs.append(dl_outputs)
        if not self.trainer.sanity_checking:
            # indicate the loop has run
            self._has_run = True

    def on_run_end(self) -> List[_OUT_DICT]:
        """Runs the ``_on_evaluation_epoch_end`` hook.

        Also calls the module's ``{validation/test}_epoch_end``, updates the epoch metrics, runs the
        ``on_{validation/test}_end`` hook and finally the ``on_{validation/test}_model_train`` hook.

        Returns:
            The value returned by ``trainer.logger_connector.update_eval_epoch_metrics()``.
        """
        outputs = self.outputs
        # free memory
        self.outputs = []
        # with a single dataloader don't pass a 2D list
        if len(outputs) > 0 and self.num_dataloaders == 1:
            outputs = outputs[0]
        # lightning module method
        self._evaluation_epoch_end(outputs)
        # hook
        self._on_evaluation_epoch_end()
        # log epoch metrics
        eval_loop_results = self.trainer.logger_connector.update_eval_epoch_metrics()
        # hook
        self._on_evaluation_end()
        # enable train mode again
        self._on_evaluation_model_train()
        return eval_loop_results

    def _get_max_batches(self) -> List[int]:
        """Returns the max number of batches for each dataloader.

        Note:
            While sanity checking, this also overwrites ``trainer.num_sanity_val_batches`` with the
            validation limits capped at ``trainer.num_sanity_val_steps``.
        """
        trainer = self.trainer
        if trainer.testing:
            return trainer.num_test_batches
        if trainer.sanity_checking:
            trainer.num_sanity_val_batches = [
                min(trainer.num_sanity_val_steps, val_batches) for val_batches in trainer.num_val_batches
            ]
            return trainer.num_sanity_val_batches
        return trainer.num_val_batches

    def _reload_evaluation_dataloaders(self) -> None:
        """Reloads dataloaders if necessary.

        The test dataloader is always reset when testing; the validation dataloader is reset only when it is
        missing or ``trainer._should_reload_val_dl`` is set.
        """
        if self.trainer.testing:
            self.trainer.reset_test_dataloader()
        elif self.trainer.val_dataloaders is None or self.trainer._should_reload_val_dl:
            self.trainer.reset_val_dataloader()

    def _on_evaluation_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_start`` hooks.

        Before the hook is called, the result collection is moved to the lightning module's device.
        Positional and keyword arguments are forwarded to the hook.
        """
        assert self._results is not None
        self._results.to(device=self.trainer.lightning_module.device)
        if self.trainer.testing:
            self.trainer.call_hook("on_test_start", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_start", *args, **kwargs)

    def _on_evaluation_model_eval(self) -> None:
        """Runs the ``on_{validation/test}_model_eval`` hook, intended to set the model to eval mode."""
        if self.trainer.testing:
            self.trainer.call_hook("on_test_model_eval")
        else:
            self.trainer.call_hook("on_validation_model_eval")

    def _on_evaluation_model_train(self) -> None:
        """Runs ``on_{validation/test}_model_train``, intended to set the model back to train mode.

        Unlike the other hooks here, the method is invoked directly on the lightning module rather than
        through ``trainer.call_hook``.
        """
        model_ref = self.trainer.lightning_module
        if self.trainer.testing:
            model_ref.on_test_model_train()
        else:
            model_ref.on_validation_model_train()

    def _on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_{validation/test}_end`` hook.

        Positional and keyword arguments are forwarded to the hook; afterwards the logger connector results
        are reset.
        """
        if self.trainer.testing:
            self.trainer.call_hook("on_test_end", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_end", *args, **kwargs)
        # reset the logger connector state
        self.trainer.logger_connector.reset_results()

    def _on_evaluation_epoch_start(self, *args: Any, **kwargs: Any) -> None:
        """Runs ``on_epoch_start`` and ``on_{validation/test}_epoch_start`` hooks.

        The logger connector is notified first; positional and keyword arguments are forwarded to the hooks.
        """
        self.trainer.logger_connector.on_epoch_start()
        self.trainer.call_hook("on_epoch_start", *args, **kwargs)
        if self.trainer.testing:
            self.trainer.call_hook("on_test_epoch_start", *args, **kwargs)
        else:
            self.trainer.call_hook("on_validation_epoch_start", *args, **kwargs)

    def _evaluation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
        """Runs ``{validation/test}_epoch_end``

        The module method is only called when the user has overridden it.

        Args:
            outputs: the collected evaluation outputs passed on to the module's epoch-end method.
        """
        # inform logger the batch loop has finished
        self.trainer.logger_connector.epoch_end_reached()
        # call the model epoch end
        model = self.trainer.lightning_module
        # unset dataloader_idx in model
        model._current_dataloader_idx = None
        if self.trainer.testing:
            if is_overridden("test_epoch_end", model):
                model._current_fx_name = "test_epoch_end"
                model.test_epoch_end(outputs)
        else:
            if is_overridden("validation_epoch_end", model):
                model._current_fx_name = "validation_epoch_end"
                model.validation_epoch_end(outputs)

    def _on_evaluation_epoch_end(self) -> None:
        """Runs ``on_{validation/test}_epoch_end`` hook.

        The generic ``on_epoch_end`` hook follows, and the logger connector is notified last.
        """
        hook_name = "on_test_epoch_end" if self.trainer.testing else "on_validation_epoch_end"
        self.trainer.call_hook(hook_name)
        self.trainer.call_hook("on_epoch_end")
        self.trainer.logger_connector.on_epoch_end()