# Example: plot the current value of a single aggregation metric.
from torchmetrics.aggregation import SumMetric

# Create the metric with its default settings.
metric = SumMetric()

# Feed one batch of three values into the metric.
metric.update([1, 2, 3])

# plot() is unpacked into two results, bound here as a figure handle and an
# axes handle (presumably matplotlib objects -- confirm against the
# torchmetrics plotting docs).
fig_, ax_ = metric.plot()