Kernel Inception Distance
Module Interface
- class torchmetrics.image.kid.KernelInceptionDistance(feature=2048, subsets=100, subset_size=1000, degree=3, gamma=None, coef=1.0, reset_real_features=True, normalize=False, **kwargs)
Calculate Kernel Inception Distance (KID), which is used to assess the quality of generated images.
\[KID = MMD(f_{real}, f_{fake})^2\]
where \(MMD\) is the maximum mean discrepancy and \(f_{real}, f_{fake}\) are features extracted from real and fake images; see kid ref1 for more details. In particular, calculating the MMD requires evaluating a polynomial kernel function \(k\)
\[k(x,y) = (\gamma * x^T y + coef)^{degree}\]
which controls the distance between two features. In practice, the MMD is calculated over a number of subsets so that both the mean and the standard deviation of KID can be reported.
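To make the two formulas above concrete, the following is a minimal pure-Python sketch of the polynomial kernel, a squared-MMD estimate, and the mean and standard deviation over random subsets. It is an illustration only, not the torchmetrics implementation: the function names (poly_kernel, mmd2_unbiased, kid_sketch), the choice of the unbiased estimator, and the use of the population standard deviation are all assumptions made for this example.

```python
import random
import statistics


def poly_kernel(x, y, gamma, coef=1.0, degree=3):
    """k(x, y) = (gamma * x^T y + coef) ** degree for two feature vectors."""
    dot = sum(a * b for a, b in zip(x, y))
    return (gamma * dot + coef) ** degree


def mmd2_unbiased(f_real, f_fake, gamma, coef=1.0, degree=3):
    """Unbiased squared-MMD estimate (an assumed estimator) for two equally sized feature sets."""
    m = len(f_real)
    if m != len(f_fake) or m < 2:
        raise ValueError("both feature sets need the same size, at least 2")

    def k(x, y):
        return poly_kernel(x, y, gamma, coef, degree)

    # Within-set terms skip the diagonal (i == j); the cross term uses all pairs.
    k_rr = sum(k(f_real[i], f_real[j]) for i in range(m) for j in range(m) if i != j)
    k_ff = sum(k(f_fake[i], f_fake[j]) for i in range(m) for j in range(m) if i != j)
    k_rf = sum(k(x, y) for x in f_real for y in f_fake)
    return (k_rr + k_ff) / (m * (m - 1)) - 2.0 * k_rf / (m * m)


def kid_sketch(f_real, f_fake, subsets, subset_size, gamma, seed=0):
    """Mean and population standard deviation of squared MMD over random subsets."""
    rng = random.Random(seed)
    scores = [
        mmd2_unbiased(rng.sample(f_real, subset_size), rng.sample(f_fake, subset_size), gamma)
        for _ in range(subsets)
    ]
    return statistics.mean(scores), statistics.pstdev(scores)


# Toy 1-D features: every "real" feature is 1.0 and every "fake" feature is 0.0.
real_feats = [[1.0] for _ in range(10)]
fake_feats = [[0.0] for _ in range(10)]
kid_mean, kid_std = kid_sketch(real_feats, fake_feats, subsets=3, subset_size=2, gamma=1.0)
```

With these toy inputs k(1, 1) = 8 and k(0, 0) = k(1, 0) = 1, so every subset scores (16 + 2) / 2 - 2 * 4 / 4 = 7 and the standard deviation over subsets is 0.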
Using the default feature extraction (Inception v3 using the original weights from kid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If argument normalize is True, images are expected to be dtype float and have values in the [0,1] range; else if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines if the images should update the statistics of the real distribution or the fake distribution.

Using a custom feature extractor is also possible. One can give a torch.nn.Module as the feature argument. This custom feature extractor is expected to have an output shape of (1, num_features). This changes the feature extractor from the default (Inception v3) to the given network. In that case the normalize argument has no effect, and the update method expects the tensor given to the imgs argument to have a shape and type compatible with the custom feature extractor.

Hint
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity
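The dtype and range conventions for normalize described above can be met by converting images before calling update. The helper below is a sketch that assumes float images in [0, 1]; the name to_uint8_images is hypothetical, and the rounding used here is an assumption for illustration rather than a statement of what the metric does internally.

```python
import torch


def to_uint8_images(imgs: torch.Tensor) -> torch.Tensor:
    """Convert float images in [0, 1] of shape (N, 3, H, W) to uint8 images in [0, 255]."""
    if not imgs.is_floating_point():
        raise TypeError("expected a floating point tensor with values in [0, 1]")
    # Clamp guards against small numeric overshoot before scaling to the 8-bit range.
    return (imgs.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)


float_imgs = torch.rand(4, 3, 32, 32)     # layout suited to normalize=True
uint8_imgs = to_uint8_images(float_imgs)  # layout suited to normalize=False
```

Either tensor could then be passed to update, provided the normalize flag given at construction matches the dtype and range of the images.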
As input to forward and update the metric accepts the following input:
- imgs (Tensor): tensor with images fed to the feature extractor, of shape (N,C,H,W)
- real (bool): bool indicating if imgs belong to the real or the fake distribution
As output of forward and compute the metric returns the following output:
- kid_mean (Tensor): float scalar tensor with mean value over subsets
- kid_std (Tensor): float scalar tensor with standard deviation value over subsets
- Parameters:
feature (Union[str, int, Module]) – Either a str, an integer or an nn.Module:
  - a str or integer will indicate the inceptionv3 feature layer to choose. Can be one of the following: 'logits_unbiased', 64, 192, 768, 2048
  - an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N,d) matrix where N is the batch size and d is the feature size.
subsets (int) – Number of subsets to calculate the mean and standard deviation scores over
subset_size (int) – Number of randomly picked samples in each subset
gamma (Optional[float]) – Scale-length of polynomial kernel. If set to None, it will be set automatically based on the feature size
reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
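As an illustration of the nn.Module option for feature, the sketch below defines a small extractor whose forward method returns an (N,d) matrix. TinyFeatureExtractor and its layer sizes are hypothetical and chosen only to keep the example short; a real extractor would normally be a pretrained network.

```python
import torch
from torch import nn


class TinyFeatureExtractor(nn.Module):
    """Hypothetical extractor mapping (N, 3, H, W) images to (N, d) features."""

    def __init__(self, num_features: int = 16) -> None:
        super().__init__()
        self.conv = nn.Conv2d(3, num_features, kernel_size=3, padding=1)
        self.pool = nn.AdaptiveAvgPool2d(1)  # global average pool to 1 x 1

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        # Cast to float so uint8 inputs are accepted as well.
        x = self.pool(self.conv(imgs.float()))
        return x.flatten(1)  # (N, d, 1, 1) -> (N, d)


extractor = TinyFeatureExtractor().eval()
with torch.no_grad():
    feats = extractor(torch.rand(4, 3, 32, 32))
```

Such a module could then be passed as the feature argument, in which case the images given to update must already have the shape and dtype the extractor expects.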
- Raises:
ValueError – If feature is set to an int (default settings) and torch-fidelity is not installed
ValueError – If feature is set to an int not in (64, 192, 768, 2048)
ValueError – If subsets is not an integer larger than 0
ValueError – If subset_size is not an integer larger than 0
ValueError – If degree is not an integer larger than 0
ValueError – If gamma is neither None nor a float larger than 0
ValueError – If coef is not a float larger than 0
ValueError – If reset_real_features is not a bool
Example
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> kid = KernelInceptionDistance(subset_size=50)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> kid.update(imgs_dist1, real=True)
>>> kid.update(imgs_dist2, real=False)
>>> kid.compute()
(tensor(0.0312), tensor(0.0025))
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.kid import KernelInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (30, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (30, 3, 299, 299), dtype=torch.uint8)
>>> metric = KernelInceptionDistance(subsets=3, subset_size=20)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute()[0])
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)