"""Example: plot a 5-class confusion matrix built from random labels.

The script feeds random predicted and target labels into a
``MulticlassConfusionMatrix`` and renders the result with ``plot()``.
"""
from torch import randint
from torchmetrics.classification import MulticlassConfusionMatrix

# The number of classes must match the label range drawn below.
metric = MulticlassConfusionMatrix(num_classes=5)

# 20 random integer labels in [0, 5) for predictions and for targets.
preds = randint(5, (20,))
target = randint(5, (20,))
metric.update(preds, target)

# plot() returns the figure and axes it drew on. The trailing underscores
# mark the results as intentionally unused in this example.
fig_, ax_ = metric.plot()