I have a problem in which I do binary classification for 2 outputs. I want to create a confusion matrix for each of them, plus some other metrics, so I was looking for the best way to do this (I am still learning PyTorch Lightning).
Currently I thought to put them in on_test_epoch_end(), but with the latest Lightning update I installed today this stopped working and gives me an error. It worked before, when the hook was called test_epoch_end(). The error is:
on_test_epoch_end() missing 1 required positional argument: 'outputs'
I tried changing the signature to (self, trainer, module_pl), but that gives errors too, all about the arguments. Any ideas? Maybe I am putting this confusion matrix in the wrong place?
My code is:
def test_step(self, batch, batch_idx):
    x, y = batch  # not sure if instead it should be batch[0], batch[1]
    y_hat = self(x)
    loss = self.criterion(y_hat, y)
    # result = pl.EvalResult()
    self.log('test_loss', loss)
    self.log('length y hat', len(y_hat))
    # accuracy = functional.accuracy(y_hat, y, task='binary')
    # f1_score_pred = functional.f1_score(y_hat, y, task='binary')  # gives me zero, so something is wrong
    # auroc = functional.auroc(y_hat, y, task='binary')
    self.log("train_loss", loss)
    # self.log("train_accuracy", accuracy)
    # self.log("train_f1", f1_score_pred)
    # self.log("train_auroc", auroc)
    return {'preds': y_hat, 'targets': y}
def on_test_epoch_end(self, outputs):
    preds = torch.cat([tmp['preds'] for tmp in outputs])
    targets = torch.cat([tmp['targets'] for tmp in outputs])
    confusion_matrix = torchmetrics.ConfusionMatrix(task='binary', num_classes=2)
    confusion_matrix(preds, targets.int())
    confusion_matrix_computed = confusion_matrix.compute().detach().cpu().numpy().astype(int)
    df_cm = pd.DataFrame(confusion_matrix_computed)
    plt.figure(figsize=(10, 7))
    fig_ = sns.heatmap(df_cm, annot=True, cmap='Spectral').get_figure()
    plt.close(fig_)
    # self.logger("Confusion matrix: ")
    self.loggers[0].experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
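From the Lightning 2.0 migration notes I gather that the *_epoch_end hooks were removed and on_test_epoch_end() no longer receives outputs, so I am apparently supposed to collect the step outputs myself in an instance attribute. Below is a minimal sketch of what I think that pattern looks like; the Linear model, the BCEWithLogitsLoss criterion, and the test_step_outputs attribute name are just placeholders I made up, and I assume the first logger is a TensorBoardLogger as in my code above. Is this the recommended place for the confusion matrix?

import torch
import torchmetrics
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Linear(10, 1)             # placeholder model
        self.criterion = torch.nn.BCEWithLogitsLoss()   # placeholder loss
        # keep the metric as a module attribute so it moves to the right device
        self.confusion_matrix = torchmetrics.ConfusionMatrix(task='binary')
        # manual collection replaces the old `outputs` argument
        self.test_step_outputs = []

    def forward(self, x):
        return self.model(x).squeeze(-1)

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y.float())
        self.log('test_loss', loss)
        # store what on_test_epoch_end() used to receive via `outputs`
        self.test_step_outputs.append({'preds': y_hat.detach(), 'targets': y.detach()})
        return loss

    def on_test_epoch_end(self):  # no `outputs` parameter in Lightning >= 2.0
        preds = torch.cat([o['preds'] for o in self.test_step_outputs])
        targets = torch.cat([o['targets'] for o in self.test_step_outputs])
        self.confusion_matrix.update(preds, targets.int())
        cm = self.confusion_matrix.compute().detach().cpu().numpy().astype(int)
        fig_ = sns.heatmap(pd.DataFrame(cm), annot=True, cmap='Spectral').get_figure()
        # assumes self.loggers[0] is a TensorBoardLogger, as in my code above
        self.loggers[0].experiment.add_figure("Confusion matrix", fig_, self.current_epoch)
        plt.close(fig_)
        self.test_step_outputs.clear()   # free memory for the next run
        self.confusion_matrix.reset()

I would then run it as usual with trainer.test(model, dataloaders=test_loader), but I am not sure whether collecting everything in a list like this is the intended replacement or whether the metric should be logged directly from test_step instead.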