This is a quick edit of your sample code, following the docs:
# Constructor sketch: the `...` in the signature and in the body are placeholders
# for elided parameters and setup code (presumably the existing constructor this
# snippet is meant to be merged into — confirm).
def __init__(self, ...):
...
# Create an Accuracy metric from pl.metrics and keep it on the instance as
# `valid_acc`; validation_step calls it on each batch and passes it to self.log.
self.valid_acc = pl.metrics.Accuracy()
def training_step(self, batch, batch_idx):
    """Compute and log the cross-entropy loss for one training batch.

    Args:
        batch: A pair that unpacks as ``(images, labels)``.
        batch_idx: Index of the batch; not used in this method.

    Returns:
        The loss produced by ``F.cross_entropy`` on the model output and
        ``labels``.
    """
    images, labels = batch
    # Calling the module itself runs the forward pass on the inputs.
    out = self(images)
    loss = F.cross_entropy(out, labels)
    # Log under the key 'loss' with both per-step and per-epoch flags enabled.
    self.log('loss', loss, on_step=True, on_epoch=True)
    return loss
def validation_step(self, batch, batch_idx):
    """Compute validation loss and accuracy for one batch and log both.

    Args:
        batch: A pair that unpacks as ``(images, labels)``.
        batch_idx: Index of the batch; not used in this method.

    Returns:
        None. Results are reported only through ``self.log``.
    """
    images, labels = batch
    out = self(images)
    loss = F.cross_entropy(out, labels)
    # torch.max along dim 1 yields (values, indices); only the indices
    # (the predicted class per row) are needed.
    _, y_hat = torch.max(out, dim=1)
    # Update the accuracy metric with predictions and targets. The targets
    # are `labels`, the name unpacked from the batch above.
    self.valid_acc(y_hat, labels)
    self.log('valid_loss', loss, on_step=True, on_epoch=True)
    # The metric object itself is handed to self.log, not a computed value.
    self.log('valid_acc', self.valid_acc, on_step=True, on_epoch=True)
def validation_epoch_end(self, outputs):
    """Intentionally do nothing at the end of a validation epoch.

    Args:
        outputs: Ignored by this method.

    Returns:
        None.
    """
    # do nothing here
    # NOTE(review): presumably left empty because validation_step already
    # logs its values with on_epoch=True — confirm against the Lightning docs.