"""Plot a single value of the PeakSignalNoiseRatio metric.

Builds a tiny pair of 2x2 prediction/target tensors, feeds them to a
torchmetrics ``PeakSignalNoiseRatio`` instance, and renders the result with
``metric.plot()``.
"""

import torch
from torchmetrics.image import PeakSignalNoiseRatio

# Metric constructed with all default arguments.
# NOTE(review): no explicit ``data_range`` is passed, so the metric presumably
# derives it from the data — confirm against the torchmetrics docs if exact
# values matter.
metric = PeakSignalNoiseRatio()

# Toy 2x2 inputs; target holds the same values as preds in reverse order, so
# the two tensors differ at every position.
preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])

# Feed the (preds, target) pair to the metric.
metric.update(preds, target)

# plot() returns a pair that is unpacked into a figure and axes; neither is used
# further in this snippet, hence the trailing underscores.
# NOTE(review): plotting presumably requires matplotlib to be installed — verify.
fig_, ax_ = metric.plot()