"""Plot accuracy, precision and recall of a binary classifier over several steps.

Each step feeds a batch of random predictions and random targets through a
``MetricCollection``. The per-step results are collected and then drawn
together on a single axes.
"""
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import (
    BinaryAccuracy,
    BinaryPrecision,
    BinaryRecall,
)

metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])

# Ten steps of 10 samples each: predictions are uniform random scores in
# [0, 1) (torch.rand), targets are random 0/1 labels (torch.randint(2, ...)).
# Calling the collection evaluates every member metric on the batch; keep one
# result per step so the history can be plotted below.
values = [metrics(torch.rand(10), torch.randint(2, (10,))) for _ in range(10)]

# ``together=True`` asks the collection to draw all of its metrics on one
# shared axes instead of one axes per metric.
fig_, ax_ = metrics.plot(values, together=True)