# Example: compute BinaryAveragePrecision on random data and plot the result.
# NOTE: the original collapsed every statement onto a single line, which is a
# SyntaxError; each statement is restored to its own line here.
import torch
from torchmetrics.classification import BinaryAveragePrecision

metric = BinaryAveragePrecision()
# 20 random scores drawn uniformly from [0, 1) paired with 20 random
# integer targets in {0, 1}.
metric.update(torch.rand(20,), torch.randint(2, (20,)))
# plot() output is unpacked into two handles, presumably (figure, axes);
# the trailing underscores mark them as unused in this example.
fig_, ax_ = metric.plot()