# Example: collect the outputs of several MeanMetric calls and plot them.
from torchmetrics.aggregation import MeanMetric

metric = MeanMetric()

# Call the metric on ten small batches [i, i + 1] and keep each call's
# return value, in order, for plotting.
values = [metric([i, i + 1]) for i in range(10)]

# Plot the collected values as a series; the figure and axes are returned.
# NOTE(review): plotting presumably needs matplotlib installed — confirm.
fig_, ax_ = metric.plot(values)