"""Example: update a GeneralizedIntersectionOverUnion metric once and plot it.

Builds one prediction dict and one target dict, feeds them to the metric via
``update`` and then calls ``plot`` on the metric.
"""

import torch
from torchmetrics.detection import GeneralizedIntersectionOverUnion

# Predictions: a one-element list holding a dict with two boxes, one score and
# one label per box.
# NOTE(review): presumably one dict per image and boxes given as
# (xmin, ymin, xmax, ymax) -- confirm against the torchmetrics documentation.
preds = [
    {
        "boxes": torch.tensor(
            [
                [296.55, 93.96, 314.97, 152.79],
                [298.55, 98.96, 314.97, 151.79],
            ]
        ),
        "scores": torch.tensor([0.236, 0.56]),
        "labels": torch.tensor([4, 5]),
    }
]

# Ground truth: a one-element list holding a dict with a single box and its
# label (no "scores" entry).
target = [
    {
        "boxes": torch.tensor([[300.00, 100.00, 315.00, 150.00]]),
        "labels": torch.tensor([5]),
    }
]

# Constructed with default arguments only.
metric = GeneralizedIntersectionOverUnion()
metric.update(preds, target)

# ``plot`` is unpacked into two values, bound here as figure and axes.
fig_, ax_ = metric.plot()