"""Plot MeanAveragePrecision over repeated, randomly perturbed predictions.

One fixed ground-truth box (label 0) is scored 20 times against a prediction
whose box coordinates and confidence are re-drawn at random on every call.
The 20 per-call results are collected and passed to ``metric.plot``.
"""
import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision


def preds():
    """Return a one-image prediction list with a randomly jittered detection.

    Each of the four box coordinates is shifted by an integer offset drawn
    from ``torch.randint(10, ...)`` (i.e. 0..9), and the score is
    ``0.536 + 0.1 * U[0, 1)``. The label is always 0.
    """
    # Drawn before the score so the RNG call order matches the original
    # lambda (boxes were evaluated before scores in the dict literal).
    jitter = torch.randint(10, (1, 4))
    return [
        dict(
            boxes=torch.tensor([[258.0, 41.0, 606.0, 285.0]]) + jitter,
            scores=torch.tensor([0.536]) + 0.1 * torch.rand(1),
            labels=torch.tensor([0]),
        )
    ]


# Single ground-truth box for one image, same class as the predictions.
target = [
    dict(
        boxes=torch.tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=torch.tensor([0]),
    )
]

metric = MeanAveragePrecision()

# Call the metric once per step with a fresh random prediction; each call's
# return value (per torchmetrics' forward semantics, the value for that batch)
# is kept so the whole sequence can be plotted.
vals = [metric(preds(), target) for _ in range(20)]

# Plot the collected per-step values; plot() returns a (figure, axes) pair.
fig_, ax_ = metric.plot(vals)