I know how to compute and accumulate a confusion matrix in plain PyTorch, and then derive per-class accuracy and overall accuracy from it. However, I have not found an easy way to do the same in PyTorch Lightning.
Is there a way in PyTorch Lightning to compute or log per-class accuracy over the entire validation dataset?
I found an example online that logs “accuracy” like this:
class MyModel(LightningModule):

    def __init__(self):
        ...
        self.accuracy = torchmetrics.Accuracy()

    def training_step(self, batch, batch_idx):
        x, y = batch
        preds = self(x)
        ...
        # log step metric
        self.accuracy(preds, y)
        self.log('train_acc_step', self.accuracy)
        ...

    def training_epoch_end(self, outs):
        # log epoch metric
        self.log('train_acc_epoch', self.accuracy)
But I don’t know how to obtain per-class accuracy this way. For example, if I compute a confusion matrix in training_step(), it covers only a single batch, and some classes may not appear in that batch at all.
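To make the goal concrete, here is a minimal sketch of the plain-PyTorch accumulation described above, independent of Lightning. The helper names `update_confusion_matrix` and `per_class_accuracy`, the three-class setup, and the toy batches are all assumptions made for illustration, and `preds` is assumed to already hold class indices (for example after an `argmax` over logits).

```python
import torch


def update_confusion_matrix(confmat, preds, target):
    """Add one batch to a running (num_classes x num_classes) count matrix.

    Rows are true classes, columns are predicted classes.
    Hypothetical helper, not part of any library.
    """
    num_classes = confmat.shape[0]
    # Encode each (target, pred) pair as one flat index, then count occurrences.
    idx = target * num_classes + preds
    counts = torch.bincount(idx, minlength=num_classes * num_classes)
    confmat += counts.reshape(num_classes, num_classes)
    return confmat


def per_class_accuracy(confmat):
    """Correct predictions per class divided by the number of samples of that class.

    A class with no samples at all yields NaN (0 / 0).
    """
    support = confmat.sum(dim=1).float()
    return confmat.diag().float() / support


num_classes = 3  # assumed for this toy example
confmat = torch.zeros(num_classes, num_classes, dtype=torch.long)

# Two toy batches of (preds, target); class 2 appears only in the second one.
batches = [
    (torch.tensor([0, 1, 1]), torch.tensor([0, 1, 0])),
    (torch.tensor([2, 0]), torch.tensor([2, 2])),
]
for preds, target in batches:
    update_confusion_matrix(confmat, preds, target)

class_acc = per_class_accuracy(confmat)  # tensor([0.5, 1.0, 0.5])
overall_acc = confmat.diag().sum().item() / confmat.sum().item()  # 0.6
```

Because the matrix is summed over every batch before the division, a class that is missing from one batch still gets a meaningful accuracy as long as it appears somewhere in the dataset.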