"""Plot the values returned by a ``SumMetric`` over ten successive calls.

Each iteration calls the metric on the two-element batch ``[i, i + 1]`` and
collects whatever the call returns; the collected values are then passed to
``SumMetric.plot``.
"""
# NOTE(review): ``rand`` and ``randint`` are not used in this script; remove
# them if nothing else depends on this module importing them.
from torch import rand, randint
from torchmetrics.aggregation import SumMetric

metric = SumMetric()

# Calling the metric object runs its forward pass on one batch. The returned
# value is presumably that batch's sum (2 * i + 1); confirm against the
# torchmetrics version in use.
values = [metric([i, i + 1]) for i in range(10)]

# ``plot`` returns a pair, presumably (figure, axes).
fig_, ax_ = metric.plot(values)