# Minimal example: feed a few values into a MeanMetric and plot the result.
from torchmetrics.aggregation import MeanMetric

# Create the aggregation metric with its default settings.
metric = MeanMetric()

# Accumulate three sample values into the metric's running state.
metric.update([1, 2, 3])

# Plot the metric's current state. The call's result is unpacked into two
# values; by their names these are presumably the matplotlib figure and
# axes -- confirm against the torchmetrics plotting docs.
fig_, ax_ = metric.plot()