"""Example: plot Peak Signal-to-Noise Ratio values collected over repeated calls.

The original source had every statement collapsed onto one line, which is not
valid Python; this restores one statement per line.
"""
import torch
from torchmetrics.image import PeakSignalNoiseRatio

metric = PeakSignalNoiseRatio()

# Small fixed 2x2 inputs; the same pair is fed on every call below.
preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])

# Calling the metric object invokes its forward pass: it returns the value for
# this call's inputs and also folds them into the metric's accumulated state.
# Collect the per-call values to plot them as a sequence.
values = [metric(preds, target) for _ in range(10)]

# plot() accepts a sequence of values and returns a (figure, axes) pair.
fig_, ax_ = metric.plot(values)