# Example: track the running min/max of a binary-accuracy metric across
# several batches with MinMaxMetric, then plot the collected values.
import torch
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.wrappers import MinMaxMetric

# Seed the global RNG so the random batches (and therefore the plot) are
# reproducible from run to run.
torch.manual_seed(0)

# Wrap BinaryAccuracy so the wrapper also keeps the min and max values seen
# across calls.
metric = MinMaxMetric(BinaryAccuracy())

# Each call feeds one batch of 20 random 0/1 predictions and 20 random 0/1
# targets, and stores what the call returns. The wrapper is stateful, so the
# running min/max update on every call.
# NOTE(review): per the torchmetrics docs the forward call returns a dict with
# "raw"/"min"/"max" entries; confirm this for the installed version.
values = [
    metric(torch.randint(2, (20,)), torch.randint(2, (20,)))
    for _ in range(3)
]

# Plot all three collected results. The call returns a (figure, axes) pair.
fig_, ax_ = metric.plot(values)