# Example: accumulate values into a MaxMetric and plot the resulting metric.
#
# The original snippet had all statements collapsed onto a single line,
# which is a SyntaxError. Each statement now sits on its own line.
from torchmetrics.aggregation import MaxMetric

# Aggregation metric that keeps the maximum of every value passed to update().
metric = MaxMetric()

# Feed a small batch of values. A plain Python list is accepted here;
# for this input the tracked maximum is 3.
metric.update([1, 2, 3])

# plot() returns a (figure, axes) pair. The trailing underscores mark the
# results as intentionally unused in this example.
# NOTE(review): plotting presumably requires matplotlib to be installed;
# confirm against the torchmetrics plotting requirements.
fig_, ax_ = metric.plot()