"""Plot example for the ``TschuprowsT`` nominal-association metric.

Builds a ``TschuprowsT`` metric, feeds it one batch of random categorical
predictions and targets, and plots the result with ``metric.plot()``, which
returns a figure and axes.
"""

import torch
from torchmetrics.nominal import TschuprowsT

# Number of categories, shared by the metric and the random data so that every
# class the metric is configured for can actually occur in the sample.
NUM_CLASSES = 5
NUM_SAMPLES = 100

# A dedicated, seeded generator keeps the plotted value reproducible across
# runs without touching torch's global RNG state.
generator = torch.Generator().manual_seed(42)

metric = TschuprowsT(num_classes=NUM_CLASSES)
# torch.randint's upper bound is exclusive: this draws labels 0..NUM_CLASSES-1.
preds = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,), generator=generator)
target = torch.randint(0, NUM_CLASSES, (NUM_SAMPLES,), generator=generator)
metric.update(preds, target)
fig_, ax_ = metric.plot()