Metrics or Callbacks?

Hello @usmanshahid,

Writing the evaluation code as a callback is a bad idea because you would have to implement the reduction and gather operations across all processes yourself. Metrics already do this for you. Have a look at how to implement your own metric:

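```python
import torch
from torchmetrics import Metric


class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # add_state registers per-process state; dist_reduce_fx="sum" tells
        # torchmetrics to sum each state across processes before compute().
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # If preds are per-class scores of shape (N, C), reduce them to labels.
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=-1)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        # Runs on the synchronized (summed) states.
        return self.correct.float() / self.total
```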

MyAccuracy inherits from Metric, whose add_state method registers the metric's state and, through dist_reduce_fx, defines how that state is synchronized across processes. You don't have to write that logic yourself, and it is built to work with Lightning.

The sample accuracy class you've given only works with a single process. In a distributed run (depending on the distributed backend), you will most likely end up with one instance of your callback per process, and those instances do not communicate with each other. Each one only sees the predictions made on its own GPU, so the results from the other GPUs are never taken into account.
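To make the difference concrete, here is a framework-free sketch that simulates two processes in plain Python. Each "rank" keeps its own correct/total counters, which is all a per-process callback can see; the synchronized result first sums the counters across ranks, which is what dist_reduce_fx="sum" asks torchmetrics to do. LocalCounts, sum_reduce and the shard data are hypothetical, chosen only for illustration.

```python
from dataclasses import dataclass


@dataclass
class LocalCounts:
    """Per-process accumulator, mirroring the correct/total states of MyAccuracy."""
    correct: int = 0
    total: int = 0

    def update(self, preds, target):
        # Count matching predictions in this process's shard only.
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)


def sum_reduce(all_counts):
    """Simulate dist_reduce_fx="sum": add each state across all processes."""
    return LocalCounts(
        correct=sum(c.correct for c in all_counts),
        total=sum(c.total for c in all_counts),
    )


# Hypothetical shards: rank 0 sees 4 samples, rank 1 sees 2.
rank0, rank1 = LocalCounts(), LocalCounts()
rank0.update(preds=[1, 0, 1, 1], target=[1, 0, 0, 1])  # 3 of 4 correct
rank1.update(preds=[0, 1], target=[1, 0])                # 0 of 2 correct

# What an unsynchronized callback on rank 0 would report.
unsynced = rank0.correct / rank0.total

# What the synchronized metric reports: reduce the states, then compute.
synced = sum_reduce([rank0, rank1])
global_acc = synced.correct / synced.total

print(f"rank 0 only: {unsynced:.3f}")    # rank 0 only: 0.750
print(f"all ranks:   {global_acc:.3f}")  # all ranks:   0.500
```

Note also that averaging the two per-rank accuracies would give (0.75 + 0.0) / 2 = 0.375, not 0.5, because the shards differ in size. Reducing the raw counts and dividing once avoids that bias, which is why MyAccuracy stores correct and total rather than an accuracy value.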

So NO, don't use callbacks to calculate metrics.

If you want to decouple the evaluation code from the model code (which is a great design pattern and how you should actually do it), create your own metric classes (using add_state to synchronize their state) and pass instances of them to your model as constructor arguments. This is known as dependency injection, if you want to read more about it.
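Here is a minimal sketch of that dependency-injection pattern, assuming only that a metric exposes update(), compute() and reset() the way a torchmetrics Metric does. To keep it runnable on its own, the model and metric are plain-Python stand-ins: MetricLike, CountingAccuracy and LitClassifier are hypothetical names, and although the method names echo Lightning's hooks, LitClassifier is not a LightningModule.

```python
from typing import Protocol, Sequence


class MetricLike(Protocol):
    """The subset of the torchmetrics Metric interface the model relies on."""
    def update(self, preds: Sequence[int], target: Sequence[int]) -> None: ...
    def compute(self) -> float: ...
    def reset(self) -> None: ...


class CountingAccuracy:
    """Hypothetical stand-in for MyAccuracy (no distributed sync)."""
    def __init__(self) -> None:
        self.reset()

    def update(self, preds, target):
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self):
        return self.correct / self.total

    def reset(self):
        self.correct = 0
        self.total = 0


class LitClassifier:
    """Hypothetical model: it knows it has a metric, not which one."""
    def __init__(self, val_metric: MetricLike) -> None:
        self.val_metric = val_metric  # injected, not constructed here

    def predict(self, batch):
        # Placeholder "model": predict 1 for positive inputs, 0 otherwise.
        return [1 if x > 0 else 0 for x in batch]

    def validation_step(self, batch, target):
        self.val_metric.update(self.predict(batch), target)

    def on_validation_epoch_end(self):
        # Compute once per epoch, then clear the state for the next epoch.
        value = self.val_metric.compute()
        self.val_metric.reset()
        return value


model = LitClassifier(val_metric=CountingAccuracy())
model.validation_step([0.5, -1.0, 2.0], [1, 0, 0])  # 2 of 3 correct
model.validation_step([-0.3], [0])                   # 1 of 1 correct
print(model.on_validation_epoch_end())               # 0.75
```

Swapping in a different metric only changes the constructor argument, not LitClassifier. With real torchmetrics objects, storing the metric as an attribute of a LightningModule also registers it as a submodule (Metric is a torch.nn.Module), so it moves to the right device together with the model.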

Also, as @Yoni_Moskovich noted, be mindful of what you return from the validation step. Only return light objects such as losses or metrics. You are returning ground-truth and prediction objects; with images, these can quickly fill up your memory, since they are kept until the end of the validation epoch. If you want to keep them for later visualization, log them with a logger such as TensorBoard or write them directly to disk.
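A minimal sketch of that advice, with a placeholder model so it runs on its own: the step writes its (potentially large) predictions and targets to disk and returns only a scalar loss, so the per-step outputs kept until the end of the epoch stay small. The validation_step signature, the out_dir parameter, the squared-error loss and the pickle format are all assumptions for illustration; with real tensors you would more likely use torch.save or a logger.

```python
import os
import pickle
import tempfile


def validation_step(batch, batch_idx, out_dir):
    """Hypothetical step: persist heavy outputs, return only light ones."""
    inputs, target = batch
    preds = [2 * x for x in inputs]  # placeholder "model"
    loss = sum((p - t) ** 2 for p, t in zip(preds, target)) / len(target)

    # Heavy objects go to disk instead of into the returned dict.
    path = os.path.join(out_dir, f"val_batch_{batch_idx:05d}.pkl")
    with open(path, "wb") as f:
        pickle.dump({"preds": preds, "target": target}, f)

    # Only a scalar is returned, so the collected outputs stay small.
    return {"val_loss": loss}


with tempfile.TemporaryDirectory() as out_dir:
    batches = [([1.0, 2.0], [2.0, 4.0]), ([3.0], [5.0])]
    outputs = [validation_step(b, i, out_dir) for i, b in enumerate(batches)]
    saved = sorted(os.listdir(out_dir))

print(outputs)  # [{'val_loss': 0.0}, {'val_loss': 1.0}]
print(saved)    # ['val_batch_00000.pkl', 'val_batch_00001.pkl']
```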

MS
