According to:
TorchMetrics in PyTorch Lightning — PyTorch-Metrics 1.3.0.post0 documentation
We are recommended to instantiate a separate metric for each stage (three in total, including test) when logging the metric object, letting Lightning take care of when to reset the metric, etc. Here is the official code (without the test stage):
    class MyModule(LightningModule):
        def __init__(self):
            ...
            self.train_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)
            self.valid_acc = torchmetrics.classification.Accuracy(task="multiclass", num_classes=num_classes)

        def training_step(self, batch, batch_idx):
            x, y = batch
            preds = self(x)
            ...
            batch_value = self.train_acc(preds, y)
            self.log('train_acc_step', batch_value)

        def on_train_epoch_end(self):
            self.train_acc.reset()

        def validation_step(self, batch, batch_idx):
            x, y = batch
            logits = self(x)
            ...
            self.valid_acc.update(logits, y)

        def on_validation_epoch_end(self):
            self.log('valid_acc_epoch', self.valid_acc.compute())
            self.valid_acc.reset()
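The same page also shows the other pattern that sentence refers to, where the metric object itself is passed to `self.log` and Lightning calls `compute()` and `reset()` at the right time. A minimal sketch of that variant (from memory, same `Accuracy` setup):

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        # forward() updates the internal state and returns the batch value;
        # logging the metric object lets Lightning compute and reset it for us
        self.train_acc(preds, y)
        self.log('train_acc', self.train_acc, on_step=True, on_epoch=True)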
My questions are:

- Since we call `reset()` in `on_train_epoch_end`, `on_validation_epoch_end` and `on_test_epoch_end`, one metric instance is enough to cover all three stages, is that right?
- When we just use `forward()` to calculate the metric for each batch, the internal state doesn't matter (even `reset()` is redundant), is that right?
- When calculating metrics that rely on the internal state (like `FID` scores), then by the same reasoning as question 1, one metric instance is still enough, is that right?
- Is there a potential risk in doing this? Will `compute()` in my second code block below still work properly under DDP mode?
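To make question 2 concrete, here is a minimal standalone sketch of how I understand `forward()` to behave (made-up tensors, `MulticlassAccuracy` just as an example):

    import torch
    from torchmetrics.classification import MulticlassAccuracy

    acc = MulticlassAccuracy(num_classes=3)
    preds = torch.tensor([0, 1, 2, 2])
    target = torch.tensor([0, 1, 1, 2])
    batch_value = acc(preds, target)  # forward(): value for this batch only
    epoch_value = acc.compute()       # but the internal state was still updated
    acc.reset()                       # so reset() still clears accumulated state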
Here is my code for a metric where I only use `torchmetrics.forward()` to calculate the per-batch values:
    import pytorch_lightning as pl
    from torchmetrics.classification import MulticlassJaccardIndex

    class mylightningmodule(pl.LightningModule):
        def __init__(self, metric=None, **kwargs):
            super().__init__()
            # lightning code
            self.save_hyperparameters(ignore=['metric'])
            self.mIoU = metric

        def training_step(self, batch, batch_idx):
            # ... forward pass producing predict_masks, masks and loss ...
            mIoU = self.mIoU(preds=predict_masks, target=masks)
            self.log('train_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            return loss

        def validation_step(self, batch, batch_idx):
            # ... forward pass producing predict_masks and masks ...
            mIoU = self.mIoU(preds=predict_masks, target=masks)
            self.log('val_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        def test_step(self, batch, batch_idx):
            # ... forward pass producing predict_masks and masks ...
            mIoU = self.mIoU(preds=predict_masks, target=masks)
            self.log('test_mIoU', mIoU, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        def on_train_epoch_end(self):
            self.mIoU.reset()

        def on_validation_epoch_end(self):
            self.mIoU.reset()

        def on_test_epoch_end(self):
            self.mIoU.reset()

    # train, validate and test code
    metric = MulticlassJaccardIndex(num_classes=5)
    model = mylightningmodule(metric=metric)
    trainer.fit(model=model, datamodule=data_module)
    trainer.validate(model=model, datamodule=data_module)
    trainer.test(model=model, datamodule=data_module)
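One thing I am unsure about in the block above: with a plain tensor and `on_epoch=True`, Lightning averages the per-batch values, which is not necessarily the epoch-level mIoU. A small sketch of the difference (made-up labels, 2 classes for brevity):

    import torch
    from torchmetrics.classification import MulticlassJaccardIndex

    miou = MulticlassJaccardIndex(num_classes=2)
    v1 = miou(torch.tensor([0, 0, 1]), torch.tensor([0, 1, 1]))  # batch 1 value
    v2 = miou(torch.tensor([1, 1]), torch.tensor([1, 1]))        # batch 2 value
    print((v1 + v2) / 2)   # what averaging per-batch values gives
    print(miou.compute())  # mIoU pooled over both batches - can differ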
Here is my code for a metric whose internal state matters (like `FID` scores):
    import pytorch_lightning as pl
    from torchmetrics.image.fid import FrechetInceptionDistance

    class mylightningmodule(pl.LightningModule):
        def __init__(self, metric=None, **kwargs):
            super().__init__()
            # lightning code
            self.save_hyperparameters(ignore=['metric'])
            # FrechetInceptionDistance
            self.FID = metric

        def validation_step(self, batch, batch_idx):
            # ... generate pred_x_0 from the real batch X_0 ...
            self.FID.update(X_0, real=True)
            self.FID.update(pred_x_0, real=False)

        def test_step(self, batch, batch_idx):
            # ... generate pred_x_0 from the real batch X_0 ...
            self.FID.update(X_0, real=True)
            self.FID.update(pred_x_0, real=False)

        def on_validation_epoch_end(self):
            FID = self.FID.compute()
            self.log('val_FID', FID, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.FID.reset()

        def on_test_epoch_end(self):
            FID = self.FID.compute()
            self.log('test_FID', FID, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
            self.FID.reset()

    # train, validate and test code
    metric = FrechetInceptionDistance()
    model = mylightningmodule(metric=metric)
    trainer.fit(model=model, datamodule=data_module)
    trainer.validate(model=model, datamodule=data_module)
    trainer.test(model=model, datamodule=data_module)
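Regarding question 4, my current understanding (an assumption I would like confirmed) is that `compute()` synchronizes the metric states across DDP processes by itself, which would make `sync_dist=True` on the computed value redundant. The epoch-end hook under that assumption would shrink to:

    def on_validation_epoch_end(self):
        # assumption: compute() gathers states from all DDP ranks first,
        # so val_fid is already a global value and needs no sync_dist=True
        val_fid = self.FID.compute()
        self.log('val_FID', val_fid, prog_bar=True, logger=True)
        self.FID.reset()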