"""Plot a single LPIPS value computed on random image batches.

Minimal example for ``LearnedPerceptualImagePatchSimilarity.plot``: the
metric is updated once with two random image batches and the resulting
value is rendered.
"""
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Fixed seed so the plotted value is the same on every run.
torch.manual_seed(42)

# ``torch.rand`` samples in [0, 1]; ``normalize=True`` tells LPIPS to expect
# that range instead of its default [-1, 1].
metric = LearnedPerceptualImagePatchSimilarity(normalize=True)
# Two batches of 10 three-channel 100x100 images, laid out as (N, C, H, W).
metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
# Called without a value, ``plot`` plots the metric's computed result and
# returns the figure and axes it drew on.
fig_, ax_ = metric.plot()