"""Plot a collection of binary-classification metrics with torchmetrics.

Builds a ``MetricCollection`` of accuracy, precision and recall, feeds it one
batch of random predictions/targets, and plots the accumulated values.
"""
import torch
from torchmetrics import MetricCollection
from torchmetrics.classification import BinaryAccuracy, BinaryPrecision, BinaryRecall

# One collection so a single update()/plot() call drives all three metrics.
metrics = MetricCollection([BinaryAccuracy(), BinaryPrecision(), BinaryRecall()])

# Predictions are floats in [0, 1); the binary metrics treat them as
# probabilities and threshold them (torchmetrics default threshold is 0.5).
preds = torch.rand(10)
# Targets are integer class labels in {0, 1}.
target = torch.randint(2, (10,))
metrics.update(preds, target)

# plot() returns the figure/axes produced for the collection (with the default
# settings, presumably one (fig, ax) pair per metric — confirm against the
# installed torchmetrics version).
fig_ax_ = metrics.plot()