# Example: plot bootstrapped mean-squared-error values over a few batches.
#
# The original script was collapsed onto a single line, which is a
# SyntaxError; it is restored here with one statement per line.
import torch

from torchmetrics.regression import MeanSquaredError
from torchmetrics.wrappers import BootStrapper

# Wrap MeanSquaredError in a BootStrapper configured with 20 bootstraps.
metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)

# Call the metric on three batches of 100 random (preds, target) pairs and
# keep every returned value so all of them can be plotted together.
values = []
for _ in range(3):
    values.append(metric(torch.randn(100), torch.randn(100)))

# Plot the collected values; the figure and axes are returned by `plot`.
fig_, ax_ = metric.plot(values)