"""Example: plot the value of a ``Running``-wrapped ``SumMetric``."""

import torch
from torchmetrics.aggregation import SumMetric
from torchmetrics.wrappers import Running

# Wrap a sum aggregator in a Running wrapper configured with a window of 2.
metric = Running(SumMetric(), window=2)

# Feed a single batch of random values, shape (20, 2).
metric.update(torch.randn(20, 2))

# plot() yields two objects, unpacked here as the figure and its axes.
fig_, ax_ = metric.plot()