# Example: compute torchmetrics' KLDivergence on ten random batches and plot
# the collected per-call values with the metric's built-in ``plot`` helper.
from torch import randn
from torchmetrics.regression import KLDivergence

metric = KLDivergence()

# Each batch is a pair of random (10, 3) score tensors. ``softmax(dim=-1)``
# turns every row into a probability distribution over the last axis.
# The value returned by each ``metric(...)`` call is kept for plotting.
values = [
    metric(randn(10, 3).softmax(dim=-1), randn(10, 3).softmax(dim=-1))
    for _ in range(10)
]

# ``plot`` receives the list of collected values and returns a
# (figure, axes) pair.
fig, ax = metric.plot(values)