Confusion about torchmetrics in pytorch_lightning

According to:

TorchMetrics in PyTorch Lightning — PyTorch-Metrics 1.3.0.post0 documentation

The documentation recommends instantiating a separate metric for each of train, validation, and test, logging the metric object, and letting Lightning take care of when to reset the metric, etc. Here is the official code (without the test part):

class MyModule(LightningModule):

    def __init__(self):
        ...
        self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
        self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        batch_value = self.train_acc(preds, y)
        self.log('train_acc_step', batch_value)

    def on_train_epoch_end(self):
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        ...
        self.valid_acc.update(logits, y)

    def on_validation_epoch_end(self):
        self.log('valid_acc_epoch', self.valid_acc.compute())
        self.valid_acc.reset()

My questions are:

  1. Since we call torchmetrics.reset() in on_train_epoch_end, on_validation_epoch_end and on_test_epoch_end, we only need one torchmetric instance to calculate everything, is that right?
  2. When we just use torchmetrics.forward() to calculate the metric of the inputs, the internal state doesn't matter (even torchmetrics.reset() is redundant), is that right? (See the sketch after this list.)
  3. When calculating a metric that relies on accumulated internal state (like FID scores), following question 1, we still only need one torchmetric instance to calculate everything, is that right?
  4. Is there a potential risk in doing this? Will torchmetrics.compute() in my second block of code below still work properly under DDP mode?
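
(A minimal sketch for question 2, using MulticlassAccuracy purely as a stand-in metric: forward() returns the value for the current batch, but it also accumulates into the internal state, so a later compute() still reflects every batch seen since the last reset().)

import torch
from torchmetrics.classification import MulticlassAccuracy

metric = MulticlassAccuracy(num_classes=3)

preds_a, target_a = torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 1, 1])
preds_b, target_b = torch.tensor([0, 1, 2, 2]), torch.tensor([0, 1, 2, 2])

batch_a = metric(preds_a, target_a)   # forward(): value for this batch only ...
batch_b = metric(preds_b, target_b)   # ... while also updating the internal state
running = metric.compute()            # value accumulated over both batches
metric.reset()                        # needed before reusing the metric elsewhere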

Here is my code for a metric where I just use torchmetrics.forward() to calculate the per-batch value:

class mylightningmodule(pl.LightningModule):
    def __init__(self, metric=None, **kwargs):
        super().__init__()
        # lightning code
        self.save_hyperparameters(ignore=['metric'])
        self.mIoU = metric

    def training_step(self, batch, batch_idx):
        # ... forward pass producing predict_masks, masks and loss ...
        mIoU = self.mIoU(preds=predict_masks, target=masks)
        self.log('train_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # ... forward pass producing predict_masks and masks ...
        mIoU = self.mIoU(preds=predict_masks, target=masks)
        self.log('val_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def test_step(self, batch, batch_idx):
        # ... forward pass producing predict_masks and masks ...
        mIoU = self.mIoU(preds=predict_masks, target=masks)
        self.log('test_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

    def on_train_epoch_end(self):
        self.mIoU.reset()

    def on_validation_epoch_end(self):
        self.mIoU.reset()

    def on_test_epoch_end(self):
        self.mIoU.reset()

# train, validate and test code
metric = MulticlassJaccardIndex(num_classes=5)
model = mylightningmodule(metric=metric)
trainer.fit(model=model, datamodule=data_module)
trainer.validate(model=model, datamodule=data_module)
trainer.test(model=model, datamodule=data_module)
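
(Aside on question 4: logging the per-batch forward() value with on_epoch=True and sync_dist=True averages the batch values across batches and processes, which for a ratio metric such as mIoU is not exactly the same as computing it once over the whole epoch. The TorchMetrics-in-Lightning docs instead suggest logging the metric object and letting Lightning call compute() at epoch end. A minimal sketch reusing the names above:)

    def validation_step(self, batch, batch_idx):
        # ... forward pass producing predict_masks and masks ...
        self.mIoU.update(preds=predict_masks, target=masks)
        # logging the metric object lets Lightning/torchmetrics handle the
        # epoch-level compute and the cross-process synchronization
        self.log('val_mIoU', self.mIoU, on_epoch=True, prog_bar=True)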

Here is my code for a metric whose accumulated internal state matters (like FID):

class mylightningmodule(pl.LightningModule):
    def __init__(self, metric=None, **kwargs):
        super().__init__()
        # lightning code
        self.save_hyperparameters(ignore=['metric'])
        # FrechetInceptionDistance
        self.FID = metric

    def validation_step(self, batch, batch_idx):
        # accumulate FID state: X_0 are real images, pred_x_0 are generated images
        self.FID.update(X_0, real=True)
        self.FID.update(pred_x_0, real=False)

    def test_step(self, batch, batch_idx):
        # accumulate FID state: X_0 are real images, pred_x_0 are generated images
        self.FID.update(X_0, real=True)
        self.FID.update(pred_x_0, real=False)

    def on_validation_epoch_end(self):
        FID = self.FID.compute()
        self.log('val_FID', FID, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.FID.reset()

    def on_test_epoch_end(self):
        FID = self.FID.compute()
        self.log('test_FID', FID, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.FID.reset()

# train, validate and test code
metric = FrechetInceptionDistance()
model = mylightningmodule(metric=metric)
trainer.fit(model=model, datamodule=data_module)
trainer.validate(model=model, datamodule=data_module)
trainer.test(model=model, datamodule=data_module)
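
(One FID-specific caveat, assuming the torchmetrics FrechetInceptionDistance implementation: by default it expects uint8 images in [0, 255]. A hedged construction for float images in [0, 1]:)

metric = FrechetInceptionDistance(
    feature=2048,               # Inception feature layer (the default)
    normalize=True,             # accept float images in [0, 1] instead of uint8
    reset_real_features=False,  # keep real-image statistics across reset() calls
)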

@awaelchli @SkafteNicki @jirka

@sznflash I believe this is just a best practice recommendation so that the user doesn’t accidentally mix up the statistics across the different datasets.


According to:

LightningModule — PyTorch Lightning 2.1.4 documentation

The on_validation_epoch_end() hook is called before on_train_epoch_end(). So if a single shared metric is used, maybe we should reset it in on_train_epoch_start(), on_validation_epoch_start() and on_test_epoch_start() instead; see the sketch below.
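
A minimal sketch of that epoch-start reset, if a single shared metric is kept anyway:

    def on_train_epoch_start(self):
        self.mIoU.reset()

    def on_validation_epoch_start(self):
        self.mIoU.reset()

    def on_test_epoch_start(self):
        self.mIoU.reset()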

Of course, the best way is instantiating three torchmetrics (including test). :smiling_face_with_three_hearts:
@awaelchli

After experimenting, I believe it is best to instantiate separate torchmetrics for the train, val, and test sets. This avoids mixing statistics across splits and improves readability.
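
A hedged sketch of that separation (Metric.clone() returns an independent copy, so no state is shared between splits):

    def __init__(self, num_classes=5, **kwargs):
        super().__init__()
        base = MulticlassJaccardIndex(num_classes=num_classes)
        # one independent copy per split
        self.train_mIoU = base.clone()
        self.valid_mIoU = base.clone()
        self.test_mIoU = base.clone()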

Hey, this confuses me a bit too.

I define a custom Metric class like this:


import torch
import torchmetrics


class Precision(torchmetrics.Metric):
    def __init__(self):
        super().__init__()
        
        self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state("num_preds", default=torch.tensor(0), dist_reduce_fx="sum")
        
    def update(self, pred, targets):
        assert len(targets) == len(pred)
        
        for idx in range(len(targets)):
            ground_truth_entities = targets[idx].split('@')
            pred_entities = pred[idx].split('@')
            
            self.num_preds += len(pred_entities)
            self.total += len(ground_truth_entities)
            
            for entity in set(pred_entities):
                if entity in ground_truth_entities:
                    self.correct += 1
    
    def compute(self):
        return self.correct.float() / self.num_preds.float(), self.correct.float()/self.total.float()
        

but inside the LightningModule I only initialize a single instance of the Precision class and use it to calculate the metric in both training_step and validation_step. It gives different numbers on the train and validation sets:

...
self.precision_score = Precision()

def training_step(self, batch, batch_idx):
    ...
    precision_score, recall_score = self.precision_score(predictions, output_sequences)

def validation_step(self, batch, batch_idx):
    ...
    precision_score, recall_score = self.precision_score(predictions, output_sequences)

Hmmm how?

@dinhngoc123 Your metric accumulates internal state, so I don't recommend sharing one instance across stages. We should instantiate a separate metric for train, validation, and test respectively. If you do want to share a single instance, I suggest resetting it at the start of each type of epoch.
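
For the custom Precision above, a minimal sketch of that separation (step bodies elided as in the original):

    def __init__(self):
        super().__init__()
        # one instance per split, so train batches never leak into the val state
        self.train_precision = Precision()
        self.valid_precision = Precision()

    def training_step(self, batch, batch_idx):
        ...
        precision_score, recall_score = self.train_precision(predictions, output_sequences)

    def validation_step(self, batch, batch_idx):
        ...
        precision_score, recall_score = self.valid_precision(predictions, output_sequences)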