"""Plot the value of ``RelativeAverageSpectralError`` (RASE) after one update.

Builds the metric, feeds it a single batch of random predictions and
targets, and renders the resulting metric value with ``metric.plot()``.
"""
import torch
from torchmetrics.image import RelativeAverageSpectralError

metric = RelativeAverageSpectralError()
# Random predictions and targets of shape (4, 3, 16, 16); presumably the
# (N, C, H, W) image layout torchmetrics image metrics expect — confirm.
metric.update(torch.rand(4, 3, 16, 16), torch.rand(4, 3, 16, 16))
# plot() returns a (figure, axes) pair; plotting presumably requires
# matplotlib to be installed.
fig_, ax_ = metric.plot()