At each training/validation step I store the number of TP, FP, TN, and FN. At the end of each epoch I want to concatenate them, but it throws an error. The code:
def shared_step(self, batch, stage):
    image, mask = batch
    out = self.forward(image)
    loss = self.criterion(out, mask.long())
    tp, fp, fn, tn = smp.metrics.get_stats(
        torch.argmax(out, 1).unsqueeze(1), mask.long(),
        mode='multiclass', num_classes=5,
    )
    self.log(f'{stage}_loss', loss)
    return {"loss": loss, "tp": tp, "fp": fp, "fn": fn, "tn": tn}

def shared_epoch_end(self, outputs, stage):
    tp = torch.cat([x["tp"] for x in outputs])
    fp = torch.cat([x["fp"] for x in outputs])
    fn = torch.cat([x["fn"] for x in outputs])
    tn = torch.cat([x["tn"] for x in outputs])
    iou = {f"{stage}_IoU": smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro-imagewise")}
    self.log_dict(iou, prog_bar=True)

def validation_step(self, batch, batch_idx):
    return self.shared_step(batch, "valid")

def on_validation_epoch_end(self, outputs):
    return self.shared_epoch_end(outputs, "valid")
And this is the error:
TypeError: on_validation_epoch_end() missing 1 required positional argument: 'outputs'
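I suspect this is related to the PyTorch Lightning 2.0 change where the epoch-end hooks stopped receiving the step outputs: validation_epoch_end(outputs) was removed, and Lightning now calls on_validation_epoch_end() with no arguments, which would explain why my extra outputs parameter triggers the TypeError. If so, a minimal sketch of what I understand the 2.x pattern to be is below; the class name MyModel and the buffer name validation_step_outputs are just placeholders, and this assumes pytorch-lightning >= 2.0 (shared_step and shared_epoch_end stay as in my code above):

import pytorch_lightning as pl

class MyModel(pl.LightningModule):  # placeholder name; rest of the class as above
    def __init__(self):
        super().__init__()
        # accumulate per-step stats myself, since the hook no longer receives them
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        out = self.shared_step(batch, "valid")
        self.validation_step_outputs.append(out)
        return out

    def on_validation_epoch_end(self):  # no `outputs` argument in Lightning >= 2.0
        self.shared_epoch_end(self.validation_step_outputs, "valid")
        self.validation_step_outputs.clear()  # reset the buffer for the next epoch

Is this the right way to do it, or is there a built-in way to get the concatenated step outputs at epoch end?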