Hi, I’m trying to log metrics only on epochs, but it doesn’t seem to work as intended.
Here is my code:
# Two stacked LSTM layers with dropout and a linear head producing one
# logit per sequence (binary classification).
class StackedLSTM(pl.LightningModule):
# NOTE(review): __init__ is elided ('...') — presumably it defines
# self.lstm1, self.lstm2, self.dropout and self.fc and saves
# hyperparameters (self.hparams.learning_rate is read below); confirm
# against the full source.
...
def forward(self, x):
    """Run the stacked LSTM and return one logit per sequence.

    Args:
        x: input batch — assumed shape (batch, seq_len, features),
           i.e. batch_first LSTMs. TODO confirm against __init__.

    Returns:
        Tensor of shape (batch,) with the raw (pre-sigmoid) logit taken
        from the last timestep.
    """
    out, _ = self.lstm1(x)
    out, _ = self.lstm2(self.dropout(out))
    # Only the last timestep feeds the classifier, so slice *before* the
    # linear layer instead of projecting every timestep and discarding
    # all but the last — identical output, far less work.
    logits = self.fc(self.dropout(out[:, -1]))
    return logits
def log_metrics(self, loss, y, probas, stage):
    """Compute and log accuracy/AP/F1/AUROC for the given stage.

    Args:
        loss: scalar loss tensor for the current batch.
        y: numpy array of binary ground-truth labels for the batch.
        probas: numpy array of predicted positive-class probabilities.
        stage: 'train', 'valid' or 'test' — used as the metric-name prefix.

    Returns:
        (acc, ap, f1, auroc) when stage == 'test', otherwise None.
    """
    preds = probas > 0.5
    acc = accuracy_score(y, preds)
    # sklearn ignores pos_label whenever average != 'binary', so the
    # original pos_label=1 arguments were no-ops and are dropped here.
    ap = average_precision_score(y, probas, average='weighted')
    f1 = f1_score(y, preds, average='weighted')
    # roc_auc_score raises ValueError when a batch contains only one
    # class; report NaN for such batches instead of crashing mid-epoch.
    if len(set(y.ravel().tolist())) > 1:
        auroc = roc_auc_score(y, probas, average='weighted')
    else:
        auroc = float('nan')
    # NOTE(review): with on_step=False/on_epoch=True Lightning still
    # *computes* these values once per batch and logs the epoch value as
    # the MEAN of the per-batch values. The mean of per-batch AUROC/AP/F1
    # is not the metric over the whole epoch — which is why results only
    # look right when one batch == the full dataset. For exact epoch-level
    # values, accumulate predictions across the epoch (e.g. torchmetrics).
    self.log_dict({
        f'{stage}_loss': loss,
        f'{stage}_acc': torch.tensor(acc),
        f'{stage}_ap': torch.tensor(ap),
        f'{stage}_f1': torch.tensor(f1),
        f'{stage}_auroc': torch.tensor(auroc)
    }, on_step=False, on_epoch=True, logger=True)
    if stage == 'test':
        return acc, ap, f1, auroc
def training_step(self, batch, batch_idx):
    """One optimization step: forward pass, loss, and metric logging.

    Args:
        batch: (x, y) pair of input tensor and float binary targets.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        The BCE-with-logits loss tensor for backpropagation.
    """
    x, y = batch
    logits = self(x)
    # Functional loss avoids constructing a new nn.BCEWithLogitsLoss
    # module object on every training step.
    loss = nn.functional.binary_cross_entropy_with_logits(logits, y)
    probas = torch.sigmoid(logits).detach().cpu().numpy()
    self.log_metrics(loss, y.detach().cpu().numpy(), probas, stage='train')
    return loss
def validation_step(self, batch, batch_idx):
    """One validation step: forward pass, loss, and metric logging.

    Args:
        batch: (x, y) pair of input tensor and float binary targets.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        The BCE-with-logits loss tensor (for Lightning bookkeeping).
    """
    x, y = batch
    logits = self(x)
    # Functional loss avoids constructing a new nn.BCEWithLogitsLoss
    # module object on every validation step.
    loss = nn.functional.binary_cross_entropy_with_logits(logits, y)
    probas = torch.sigmoid(logits).detach().cpu().numpy()
    self.log_metrics(loss, y.detach().cpu().numpy(), probas, stage='valid')
    return loss
def test_step(self, batch, batch_idx):
    """One test step: forward pass, loss, metrics and confusion matrix.

    Args:
        batch: (x, y) pair of input tensor and float binary targets.
        batch_idx: index of the batch within the epoch (unused).

    Returns:
        ({'acc', 'ap', 'f1', 'auroc'} dict, confusion-matrix ndarray) —
        this exact tuple shape is consumed by test_step_end.
    """
    x, y = batch
    logits = self(x)
    probas = torch.sigmoid(logits).detach().cpu().numpy()
    _y = y.detach().cpu().numpy()
    # Functional loss avoids constructing a new nn.BCEWithLogitsLoss
    # module object on every test step.
    loss = nn.functional.binary_cross_entropy_with_logits(logits, y)
    acc, ap, f1, auroc = self.log_metrics(loss, _y, probas, stage='test')
    # normalize=None is the confusion_matrix default (raw counts).
    cm = confusion_matrix(_y, probas > 0.5)
    return {'acc': acc, 'ap': ap, 'f1': f1, 'auroc': auroc}, cm
# Stash the test-step output on the module so it can be read back after
# trainer.test(). NOTE(review): this hook runs once per test batch, so
# self.results keeps only the LAST batch's (metrics dict, confusion
# matrix) tuple — fine when the test set is a single batch, otherwise
# the earlier batches' results are silently overwritten.
def test_step_end(self, outputs):
self.results = outputs
def configure_optimizers(self):
    """Build and return the Adam optimizer over all model parameters,
    using the learning rate saved in self.hparams."""
    return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
I define my own log_metrics function and call it from each stage's step method. When I run this, even though I pass on_step=False and on_epoch=True to log_dict, the metrics appear to be computed and logged for every step. It only behaves as expected when I feed the entire dataset as a single batch (so that 1 batch = 1 epoch).
Am I missing something here? I’d much appreciate any help!