"""Plot the InceptionScore computed on several batches of random uint8 images."""
import torch
from torchmetrics.image.inception import InceptionScore

metric = InceptionScore()
values = []
for _ in range(3):
    # torch.randint's ``high`` bound is exclusive, so 256 is required for the
    # generated pixels to cover the full uint8 range [0, 255].
    imgs = torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8)
    # Calling the metric returns a tuple; we index by 0 so that only the mean
    # value is plotted.
    values.append(metric(imgs)[0])
fig_, ax_ = metric.plot(values)