"""Plot Kernel Inception Distance (KID) values collected over several runs.

Each of three iterations does the following:

1. Feed one batch of random "real" images and one batch of random "fake"
   images into ``KernelInceptionDistance``.
2. Record the first element of ``compute()``'s result.
3. Reset the metric.

The collected values are then drawn with ``metric.plot``.
"""

import torch
from torchmetrics.image.kid import KernelInceptionDistance


def imgs_dist1():
    """Return a random uint8 batch of shape (30, 3, 299, 299) with values in [0, 200)."""
    return torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)


def imgs_dist2():
    """Return a random uint8 batch of shape (30, 3, 299, 299) with values in [100, 255)."""
    # The value range overlaps imgs_dist1 only partially, so the two
    # batches come from visibly different distributions.
    return torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)


# subset_size (20) is kept below the 30 images per batch; per the torchmetrics
# docs, KID requires subset_size not to exceed the number of samples.
metric = KernelInceptionDistance(subsets=3, subset_size=20)
values = []
for _ in range(3):
    metric.update(imgs_dist1(), real=True)
    metric.update(imgs_dist2(), real=False)
    # Keep only the first element of compute()'s result. Per the torchmetrics
    # docs, this is the KID mean; the second element is its std.
    values.append(metric.compute()[0])
    # Clear the accumulated features so each iteration yields an independent value.
    metric.reset()

fig_, ax_ = metric.plot(values)