# Example: compute a Dice metric on random binary data and plot the result.
#
# The original was collapsed onto a single line, which is a SyntaxError;
# each statement now sits on its own line.
from torch import randint
from torchmetrics.classification import Dice

metric = Dice()
# Feed one batch: 10 random predictions and 10 random targets, each drawn
# from {0, 1} (randint's upper bound 2 is exclusive).
metric.update(randint(2, (10,)), randint(2, (10,)))
# plot() returns a (figure, axes) pair; trailing underscores mark them as
# unused in this example.
fig_, ax_ = metric.plot()