"""Plot a bootstrapped mean-squared-error metric.

Wraps ``MeanSquaredError`` in a ``BootStrapper`` configured with 20
bootstrap resamples, feeds it a single batch of random values, and renders
the result with the metric's ``plot`` method.
"""
import torch
from torchmetrics.wrappers import BootStrapper
from torchmetrics.regression import MeanSquaredError

metric = BootStrapper(MeanSquaredError(), num_bootstraps=20)
# A single update: two independent 1-D tensors of 100 random values each.
metric.update(torch.randn(100), torch.randn(100))
# `plot()` returns a (figure, axes) pair; both are kept so the figure stays referenced.
fig_, ax_ = metric.plot()