# Example: compute PeakSignalNoiseRatioWithBlockedEffect (PSNR-B) on random
# inputs and plot the resulting value.
#
# The original one-line form had no statement separators and was a
# SyntaxError; each statement is now on its own line.
import torch
from torchmetrics.image import PeakSignalNoiseRatioWithBlockedEffect

# Seed the global RNG so the random inputs (and thus the plotted value) are
# reproducible across runs.
torch.manual_seed(42)

metric = PeakSignalNoiseRatioWithBlockedEffect()

# Two random tensors of shape (2, 1, 10, 10), values drawn by torch.rand.
# NOTE(review): presumably laid out as (N, C, H, W) with a single channel;
# confirm against the metric's documented input requirements.
metric.update(torch.rand(2, 1, 10, 10), torch.rand(2, 1, 10, 10))

# plot() returns a (figure, axes) pair; the trailing underscores mark the
# names as intentionally unused in this example.
fig_, ax_ = metric.plot()