"""Example: plot the values produced by a windowed ``Running`` wrapper around ``SumMetric``."""
import torch
from torchmetrics.aggregation import SumMetric
from torchmetrics.wrappers import Running

# Wrap SumMetric in Running with a window of 2. Per the torchmetrics docs,
# the running value is computed over the most recent ``window`` updates.
metric = Running(SumMetric(), window=2)

# Each call to the metric (its forward pass) updates the running state and
# returns a value. Collect one value per random (20, 2) batch so the
# history can be plotted. The inputs are unseeded, so the plot differs
# between runs.
values = [metric(torch.randn(20, 2)) for _ in range(3)]

# plot() returns a (figure, axes) pair. The trailing underscores mark the
# results as intentionally unused in this example.
fig_, ax_ = metric.plot(values)