"""Plot LPIPS scores collected from several forward calls on random images.

Each call compares two random image batches of shape (10, 3, 100, 100).
The three collected scores are then passed to ``metric.plot``.
"""
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity

# Fix the RNG so the plotted values are reproducible between runs.
torch.manual_seed(42)

# torch.rand samples from [0, 1). normalize=True tells LPIPS to expect
# [0, 1] inputs instead of its default [-1, 1] range.
metric = LearnedPerceptualImagePatchSimilarity(normalize=True)

# Collect one score per forward call, giving three points to plot.
values = [
    metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
    for _ in range(3)
]
fig_, ax_ = metric.plot(values)