Metrics or Callbacks?


I am trying to write some callbacks to compute a few evaluation metrics and to store my model's predictions. I have already come across the Metric class, but I want to keep the evaluation code separate from the model code, so I was thinking of writing it as callbacks.

I need some help understanding why we need the Metric class in the first place, and why I can't just use callbacks to compute results. I have written a callback for accuracy, which can be found here, but I have doubts about whether it is correct. Under which circumstances would it not work? I have a hunch that there might be issues in distributed setups, but does it at least work well in single-accelerator environments?

My model returns the following dictionary as the output of its *_step functions:

return {'loss': loss, 'idx': idx, 'pred': pred, 'gt': gt}

Additionally, is it possible to use the Metric class in a callback? More generally, I am concerned about whether my callback approach would work at all, since I want to compute metrics beyond those provided as part of PyTorch Lightning. Any help would be appreciated.


I was wondering the same thing.
A related issue I encounter is this: in callbacks, I need the input and the output of the model to compute metrics. But to have them there, I need to return them from LightningModule.validation_step(). However, everything LightningModule.validation_step() returns is saved, so my memory fills up pretty fast. Can't we have the option NOT to save outputs?
Apologies if my reply shouldn't have been posted here.


Completely agree. Any updates on this, or better practices?

I was also wondering the same thing.

Callbacks seem like the most convenient place to add and log metrics, but since all the step outputs are kept in memory until the end of the epoch, this is not practical.

Is there a way to disable saving all step outputs to memory?
Or perhaps a different way to work with callbacks in this manner?


Hello @usmanshahid,

Writing the evaluation code as a callback is a bad idea because you would have to implement the reduction and gather operations across all the processes yourself. TorchMetrics' Metric class already does this for you. Have a look at how to implement your own metrics:

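import torch
from torchmetrics import Metric

class MyAccuracy(Metric):
    def __init__(self):
        super().__init__()
        # Summable states: dist_reduce_fx="sum" adds them up across processes before compute().
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, preds: torch.Tensor, target: torch.Tensor):
        # Metric has no built-in _input_format helper; here logits/probabilities
        # of shape (N, C) are turned into class indices of shape (N,) directly.
        if preds.ndim == target.ndim + 1:
            preds = preds.argmax(dim=-1)
        assert preds.shape == target.shape

        self.correct += torch.sum(preds == target)
        self.total += target.numel()

    def compute(self):
        return self.correct.float() / self.total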

The MyAccuracy class inherits from the Metric class, which implements the add_state method that manages synchronization between processes for you. This is quite useful, since you don't have to rewrite it and you can be sure that it works with Lightning.
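To make the add_state idea concrete, here is a plain-Python toy (no torch, no torchmetrics) that mimics the pattern: each state is registered together with a reduction function, every process updates only its local copy, and a sync step combines the copies with that function before compute(). ToyMetric, ToyAccuracy and sync() are hypothetical names for illustration, not the torchmetrics API; roughly speaking, the real Metric gathers the states from all processes with torch.distributed and applies dist_reduce_fx for you.

```python
from typing import Callable, Dict, List


class ToyMetric:
    """Hypothetical stand-in for the add_state idea; not the torchmetrics API."""

    def __init__(self) -> None:
        self._reducers: Dict[str, Callable[[List[int]], int]] = {}

    def add_state(self, name: str, default: int, reduce_fx: Callable[[List[int]], int]) -> None:
        # Register the state and remember how copies from different processes are combined.
        self._reducers[name] = reduce_fx
        setattr(self, name, default)

    def sync(self, replicas: List["ToyMetric"]) -> None:
        # Simulated synchronization: reduce each state over all (simulated) processes.
        for name, reduce_fx in self._reducers.items():
            setattr(self, name, reduce_fx([getattr(r, name) for r in replicas]))


class ToyAccuracy(ToyMetric):
    def __init__(self) -> None:
        super().__init__()
        self.add_state("correct", 0, sum)  # summed across processes, like dist_reduce_fx="sum"
        self.add_state("total", 0, sum)

    def update(self, preds: List[int], target: List[int]) -> None:
        assert len(preds) == len(target)
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total


# Two simulated processes, each seeing a different shard of the validation data.
rank0, rank1 = ToyAccuracy(), ToyAccuracy()
rank0.update([1, 0, 1, 1], [1, 0, 0, 1])  # 3 of 4 correct
rank1.update([0, 0], [1, 0])              # 1 of 2 correct

print(rank0.compute())  # 0.75: rank 0 only knows about its own shard
rank0.sync([rank0, rank1])
print(rank0.compute())  # 4 / 6: computed from the summed states
```

The key design choice is that the states (correct, total) are summable counts rather than a ready-made accuracy, so combining them across processes is a simple sum.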

The sample accuracy callback you've given only works with one process. If it is launched in a distributed environment (depending on which distributed implementation is used), you will most likely get multiple instances of your callback, one per process, that do not talk to each other. Each instance only sees the batches assigned to its own process, so the predictions made on the other GPUs are not taken into account.
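A quick arithmetic check shows what goes wrong. The sketch below is plain Python with made-up shard contents, used only for illustration. It compares what each per-process callback would report on its own, a naive average of those reports, and the accuracy obtained by summing the correct/total counts first (what dist_reduce_fx="sum" does). Without any reduction, each process only logs its own number, and averaging per-process accuracies afterwards is only correct when every shard has the same size.

```python
from typing import List, Tuple

# Hypothetical (pred, target) pairs seen by each process; shard sizes differ on purpose.
shards: List[List[Tuple[int, int]]] = [
    [(1, 1), (0, 0), (1, 0), (1, 1)],  # rank 0: 3 of 4 correct
    [(0, 1), (0, 0)],                  # rank 1: 1 of 2 correct
]


def local_counts(shard: List[Tuple[int, int]]) -> Tuple[int, int]:
    # What a per-process callback can see: only its own shard.
    correct = sum(p == t for p, t in shard)
    return correct, len(shard)


per_rank = [local_counts(s) for s in shards]
per_rank_acc = [c / n for c, n in per_rank]          # each callback's own report
mean_of_acc = sum(per_rank_acc) / len(per_rank_acc)  # naive average of those reports
# Sum the states across processes first, then divide once.
global_acc = sum(c for c, _ in per_rank) / sum(n for _, n in per_rank)

print(per_rank_acc)  # [0.75, 0.5]
print(mean_of_acc)   # 0.625
print(global_acc)    # 0.6666666666666666
```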

So no, don't use callbacks to calculate metrics.

If you want to decouple the evaluation code from the model code (which is a great design pattern and how you actually should do it), create your own metric classes (using add_state to synchronise their states) and pass instances of them as parameters to your model. This is also known as dependency injection, if you want to read more about it.
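A minimal plain-Python sketch of that dependency-injection pattern follows. Classifier, Accuracy, evaluate_batch and end_epoch are hypothetical names chosen for illustration. In a real LightningModule the update calls would sit in validation_step, the compute/reset calls in an epoch-end hook, and the injected objects would be torchmetrics Metric instances, registered as submodules (for example through a MetricCollection) so they follow the model to its device.

```python
from typing import Dict, List, Protocol


class MetricLike(Protocol):
    # The only interface the model relies on; torchmetrics Metric objects also offer update/compute/reset.
    def update(self, preds: List[int], target: List[int]) -> None: ...
    def compute(self) -> float: ...
    def reset(self) -> None: ...


class Accuracy:
    """Toy metric with summable state (hypothetical, for illustration)."""

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def update(self, preds: List[int], target: List[int]) -> None:
        self.correct += sum(p == t for p, t in zip(preds, target))
        self.total += len(target)

    def compute(self) -> float:
        return self.correct / self.total if self.total else 0.0


class Classifier:
    """Hypothetical model wrapper: it never names a concrete metric, it only uses injected ones."""

    def __init__(self, metrics: Dict[str, MetricLike]) -> None:
        self.metrics = metrics  # injected from outside, so evaluation code lives elsewhere

    def predict(self, batch: List[int]) -> List[int]:
        # Stand-in "model": predicts 1 for positive inputs, else 0.
        return [1 if x > 0 else 0 for x in batch]

    def evaluate_batch(self, batch: List[int], target: List[int]) -> None:
        preds = self.predict(batch)
        for metric in self.metrics.values():
            metric.update(preds, target)  # only counts are kept, not the predictions

    def end_epoch(self) -> Dict[str, float]:
        results = {name: m.compute() for name, m in self.metrics.items()}
        for m in self.metrics.values():
            m.reset()
        return results


model = Classifier(metrics={"acc": Accuracy()})
model.evaluate_batch([3, -1, 2], [1, 0, 0])  # preds [1, 0, 1]: 2 of 3 correct
model.evaluate_batch([-5], [0])              # preds [0]: 1 of 1 correct
print(model.end_epoch())  # {'acc': 0.75}
```

Swapping in a different metric only changes the dictionary passed to the constructor; the model code stays untouched.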

Also, as @Yoni_Moskovich noted, you should be mindful of what you return from the validation step. You should only return light objects such as losses or metrics. You are returning ground-truth and prediction objects; in the case of images, this can quickly fill up your memory, since they are kept until the end of the validation epoch. If you want to save them for later visualization, use a logger such as TensorBoard or write them directly to disk.
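As a sketch of the "write them directly to disk" option: the function below appends each batch's predictions to a JSON Lines file and returns only the loss. It is plain Python; validation_step_sketch, load_predictions and the file layout are hypothetical, and the idx/pred/gt keys simply mirror the dictionary from the original question.

```python
import json
import os
import tempfile
from typing import Dict, List


def validation_step_sketch(batch_idx: int, idx: List[int], pred: List[int], gt: List[int],
                           loss: float, out_path: str) -> Dict[str, float]:
    # Append this batch's predictions to disk instead of returning them.
    with open(out_path, "a", encoding="utf-8") as f:
        for i, p, g in zip(idx, pred, gt):
            f.write(json.dumps({"batch": batch_idx, "idx": i, "pred": p, "gt": g}) + "\n")
    # Only a light object goes back to the training loop.
    return {"loss": loss}


def load_predictions(out_path: str) -> List[dict]:
    # Read the predictions back later, e.g. for visualization.
    with open(out_path, encoding="utf-8") as f:
        return [json.loads(line) for line in f]


out_path = os.path.join(tempfile.mkdtemp(), "val_predictions.jsonl")
print(validation_step_sketch(0, [0, 1], [1, 0], [1, 1], 0.42, out_path))  # {'loss': 0.42}
print(validation_step_sketch(1, [2], [0], [0], 0.37, out_path))           # {'loss': 0.37}
print(len(load_predictions(out_path)))  # 3
```

In a distributed run, each process would need its own output file (for example, one per rank) so that processes do not write to the same file concurrently.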



Thanks all for your help! Can this discussion be summed up as follows?

  1. Return only lightweight objects from the step functions; returning heavy objects will quickly use up your memory.
  2. Use TorchMetrics to compute metrics, because it synchronizes metric states across processes automatically.
  3. If metrics are self-implemented and computed in callbacks, each process in distributed training runs its own copy of the callback on its own share of the data, so the results are never combined and can be wrong.