# Plot example for CompleteIntersectionOverUnion: call the metric over 20 steps
# against randomly jittered ground-truth boxes and plot the collected values.
import torch
from torchmetrics.detection import CompleteIntersectionOverUnion

# Two predicted boxes for a single image, with confidence scores and class labels.
# NOTE(review): coordinates look like (x1, y1, x2, y2) corners, which presumably
# matches the metric's default box format -- confirm against the metric's docs.
preds = [
    {
        "boxes": torch.tensor(
            [
                [296.55, 93.96, 314.97, 152.79],
                [298.55, 98.96, 314.97, 151.79],
            ]
        ),
        "scores": torch.tensor([0.236, 0.56]),
        "labels": torch.tensor([4, 5]),
    }
]


def target():
    """Return a one-image target list whose single box is randomly jittered.

    Each coordinate of the base box ``[300, 100, 315, 150]`` is shifted by an
    integer from ``torch.randint(-10, 10, (1, 4))``; the upper bound is
    exclusive, so offsets lie in [-10, 9]. The RNG is not seeded here, so the
    jitter (and therefore the plotted values) differs from run to run.
    """
    return [
        {
            "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]])
            + torch.randint(-10, 10, (1, 4)),
            "labels": torch.tensor([5]),
        }
    ]


metric = CompleteIntersectionOverUnion()
# Call the metric once per step with a freshly jittered target and keep each
# call's return value so the sequence can be plotted.
vals = [metric(preds, target()) for _ in range(20)]
# plot() returns two objects here; by naming these are the figure and axes.
fig_, ax_ = metric.plot(vals)