# Example: compute a KL divergence with torchmetrics and plot the result.
from torch import randn
from torchmetrics.regression import KLDivergence

metric = KLDivergence()
# Two batches of 10 random 3-way distributions. softmax(dim=-1) makes each
# row non-negative and sum to 1.
metric.update(randn(10, 3).softmax(dim=-1), randn(10, 3).softmax(dim=-1))
# Metric.plot() returns the figure and axes it draws on. The trailing
# underscores mark them as intentionally unused here.
fig_, ax_ = metric.plot()