# Example: plot a BinaryAccuracy metric wrapped in MinMaxMetric after one
# update with a random batch of binary predictions and targets.
import torch
from torchmetrics.classification import BinaryAccuracy
from torchmetrics.wrappers import MinMaxMetric

# Seed the RNG so the random batch below, and therefore the rendered figure,
# is the same on every run.
torch.manual_seed(42)

# NOTE(review): per the torchmetrics docs, MinMaxMetric wraps a base metric and
# additionally tracks the min/max of its computed value — confirm for the
# installed torchmetrics version.
metric = MinMaxMetric(BinaryAccuracy())

# 20 random integer labels in {0, 1}; `preds` is drawn before `target`, the same
# order as the original inline call.
preds = torch.randint(2, (20,))
target = torch.randint(2, (20,))
metric.update(preds, target)

# plot() returns two values, presumably a matplotlib (figure, axes) pair.
fig_, ax_ = metric.plot()