"""Example: plot an R2Score computed per output via MultioutputWrapper."""
import torch
from torchmetrics.wrappers import MultioutputWrapper
from torchmetrics.regression import R2Score

# Wrap a single-output R2Score so it is tracked separately for each of the
# 2 output columns (2 is passed as MultioutputWrapper's second argument).
metric = MultioutputWrapper(R2Score(), 2)

# Feed one batch of random predictions and targets, both shaped (20, 2):
# 20 samples, one column per wrapped output.
metric.update(torch.randn(20, 2), torch.randn(20, 2))

# plot() returns a (figure, axes) pair.
# NOTE(review): plotting presumably requires matplotlib to be installed — confirm.
fig_, ax_ = metric.plot()