# Example: compute an Inception Score on random images and plot the result.
import torch
from torchmetrics.image.inception import InceptionScore

metric = InceptionScore()

# Feed a batch of 50 random uint8 images of shape (3, 299, 299). The upper
# bound of torch.randint is exclusive, so 256 (not 255) is required for the
# samples to cover the full uint8 range 0..255.
metric.update(torch.randint(0, 256, (50, 3, 299, 299), dtype=torch.uint8))

fig_, ax_ = metric.plot()  # the returned plot only shows the mean value by default