from torch import tensor from torchmetrics.detection import PanopticQuality preds = tensor([[[[6, 0], [0, 0], [6, 0], [6, 0]], [[0, 0], [0, 0], [6, 0], [0, 1]], [[0, 0], [0, 0], [6, 0], [0, 1]], [[0, 0], [7, 0], [6, 0], [1, 0]], [[0, 0], [7, 0], [7, 0], [7, 0]]]]) target = tensor([[[[6, 0], [0, 1], [6, 0], [0, 1]], [[0, 1], [0, 1], [6, 0], [0, 1]], [[0, 1], [0, 1], [6, 0], [1, 0]], [[0, 1], [7, 0], [1, 0], [1, 0]], [[0, 1], [7, 0], [7, 0], [7, 0]]]]) metric = PanopticQuality(things = {0, 1}, stuffs = {6, 7}) metric.update(preds, target) fig_, ax_ = metric.plot()