# Example: plot ModifiedPanopticQuality values collected over repeated calls.
# The same prediction/target pair is fed to the metric 20 times. Each forward
# call returns a value, and the collected list is handed to ``metric.plot``,
# which returns a (figure, axes) pair.
from torch import tensor
from torchmetrics.detection import ModifiedPanopticQuality

# Both tensors have shape (1, 5, 4, 2): one sample of 5 x 4 points, with two
# ids per point. Presumably the pair is (category_id, instance_id), following
# the torchmetrics panoptic input convention -- confirm against the metric docs.
preds = tensor(
    [
        [
            [[6, 0], [0, 0], [6, 0], [6, 0]],
            [[0, 0], [0, 0], [6, 0], [0, 1]],
            [[0, 0], [0, 0], [6, 0], [0, 1]],
            [[0, 0], [7, 0], [6, 0], [1, 0]],
            [[0, 0], [7, 0], [7, 0], [7, 0]],
        ]
    ]
)
target = tensor(
    [
        [
            [[6, 0], [0, 1], [6, 0], [0, 1]],
            [[0, 1], [0, 1], [6, 0], [0, 1]],
            [[0, 1], [0, 1], [6, 0], [1, 0]],
            [[0, 1], [7, 0], [1, 0], [1, 0]],
            [[0, 1], [7, 0], [7, 0], [7, 0]],
        ]
    ]
)

# Category ids 0 and 1 are declared as "things"; 6 and 7 as "stuffs".
metric = ModifiedPanopticQuality(things={0, 1}, stuffs={6, 7})

# Call the metric 20 times on the same pair and keep every returned value.
vals = [metric(preds, target) for _ in range(20)]
fig_, ax_ = metric.plot(vals)