Frechet Inception Distance (FID)
Module Interface
- class torchmetrics.image.fid.FrechetInceptionDistance(feature=2048, reset_real_features=True, normalize=False, input_img_size=(3, 299, 299), **kwargs)
Calculate the Fréchet inception distance (FID), which is used to assess the quality of generated images.
\[FID = \|\mu - \mu_w\|^2 + tr(\Sigma + \Sigma_w - 2(\Sigma \Sigma_w)^{\frac{1}{2}})\]

where \(\mathcal{N}(\mu, \Sigma)\) is the multivariate normal distribution estimated from Inception v3 (fid ref1) features calculated on real images and \(\mathcal{N}(\mu_w, \Sigma_w)\) is the multivariate normal distribution estimated from Inception v3 features calculated on generated (fake) images. The metric was originally proposed in fid ref1.
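To make the formula concrete, here is a minimal sketch that evaluates it directly from precomputed feature statistics with scipy (this mirrors the formula above; it is not the metric's internal implementation):

>>> import numpy as np
>>> from scipy import linalg
>>> def fid_from_stats(mu, sigma, mu_w, sigma_w):
...     # squared Euclidean distance between the feature means
...     diff = mu - mu_w
...     # matrix square root of the covariance product; drop the tiny
...     # imaginary parts that arise from numerical error
...     covmean = linalg.sqrtm(sigma @ sigma_w, disp=False)[0].real
...     return diff @ diff + np.trace(sigma + sigma_w - 2.0 * covmean)
>>> score = fid_from_stats(np.zeros(4), np.eye(4), np.zeros(4), np.eye(4))  # identical Gaussians give 0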
Using the default feature extractor (Inception v3 using the original weights from fid ref2), the input is expected to be mini-batches of 3-channel RGB images of shape (3xHxW). If the argument normalize is True, images are expected to have dtype float and values in the [0, 1] range; if normalize is set to False, images are expected to have dtype uint8 and take values in the [0, 255] range. All images will be resized to 299 x 299, which is the size of the original training data. The boolean flag real determines whether the images should update the statistics of the real distribution or the fake distribution.
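For illustration, the two input conventions look like this (a minimal sketch; the batch and image sizes are arbitrary):

>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> # normalize=False (default): uint8 images with values in [0, 255]
>>> fid = FrechetInceptionDistance(feature=64)
>>> fid.update(torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8), real=True)
>>> # normalize=True: float images with values in [0, 1]
>>> fid_norm = FrechetInceptionDistance(feature=64, normalize=True)
>>> fid_norm.update(torch.rand(8, 3, 299, 299), real=True)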
Using a custom feature extractor is also possible: one can pass a torch.nn.Module as the feature argument, which replaces the default Inception v3 extractor with the given network. The custom feature extractor is expected to have output shape (1, num_features). If the network does not have a num_features attribute, a random tensor will be passed through the network to infer the feature dimensionality; the size of this tensor can be controlled by the input_img_size argument and its type by the normalize argument (True uses float32 tensors and False uses int8 tensors). In this case, the update method expects the tensor given to the imgs argument to be of a shape and type compatible with the custom feature extractor.

This metric is known to be unstable in its calculations, and for best results we recommend computing it in torch.float64 (the default is torch.float32), which can be set using the .set_dtype method of the metric.
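As a minimal sketch of both options (the toy extractor below is purely illustrative; any module returning an (N, num_features) matrix works):

>>> import torch
>>> from torch import nn
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> # a toy custom extractor mapping images to 16-dim features
>>> extractor = nn.Sequential(nn.AdaptiveAvgPool2d(4), nn.Flatten(), nn.LazyLinear(16))
>>> fid = FrechetInceptionDistance(feature=extractor, normalize=True, input_img_size=(3, 64, 64))
>>> fid.update(torch.rand(8, 3, 64, 64), real=True)
>>> # with the default extractor, double precision improves numerical stability
>>> fid64 = FrechetInceptionDistance(feature=64).set_dtype(torch.float64)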
Hint
Using this metric with the default feature extractor requires that torch-fidelity is installed. Either install as pip install torchmetrics[image] or pip install torch-fidelity.
As input to forward and update the metric accepts the following input:

- imgs (Tensor): tensor with images fed to the feature extractor
- real (bool): bool indicating whether imgs belong to the real or the fake distribution
As output of forward and compute the metric returns the following output:

- fid (Tensor): float scalar tensor with the mean FID value over samples
- Parameters:
- feature (Union[int, Module]) – Either an integer or an nn.Module:
  - an integer will indicate the Inception v3 feature layer to choose. Can be one of the following: 64, 192, 768, 2048
  - an nn.Module for using a custom feature extractor. Expects that its forward method returns an (N, d) matrix where N is the batch size and d is the feature size.
- reset_real_features (bool) – Whether to also reset the real features. Since in many cases the real dataset does not change, the features can be cached to avoid recomputing them, which is costly. Set this to False if your dataset does not change (see the sketch after this parameter list).
- normalize (bool) – Argument for controlling the input image dtype normalization:
  - If the default feature extractor is used, controls whether input imgs have values in the range [0, 1] or not:
    - True: input imgs have values in [0, 1]. They are cast to int8/byte tensors.
    - False: input imgs have values in [0, 255]. No casting is done.
  - If a custom feature extractor module is used, controls the type of the input img tensors:
    - True: input imgs are expected to have dtype torch.float32.
    - False: input imgs are expected to have dtype torch.int8.
- input_img_size (tuple[int, int, int]) – Tuple of integers indicating the input image size to the custom feature extractor network, if one is provided.
- kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
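For example, caching the real statistics across evaluation rounds might look like this (a sketch; the random tensors stand in for real and generated batches):

>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64, reset_real_features=False)
>>> real_imgs = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
>>> fid.update(real_imgs, real=True)  # real statistics are computed once
>>> for _ in range(2):  # e.g. once per evaluation epoch
...     fake_imgs = torch.randint(0, 255, (8, 3, 299, 299), dtype=torch.uint8)
...     fid.update(fake_imgs, real=False)
...     score = fid.compute()
...     fid.reset()  # clears only the fake features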
- Raises:
- ValueError – If torch version is lower than 1.9
- ModuleNotFoundError – If feature is set to an int (default settings) and torch-fidelity is not installed
- ValueError – If feature is set to an int not in [64, 192, 768, 2048]
- TypeError – If feature is not a str, int or torch.nn.Module
- ValueError – If reset_real_features is not a bool
Example
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> fid = FrechetInceptionDistance(feature=64)
>>> # generate two slightly overlapping image intensity distributions
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> fid.update(imgs_dist1, real=True)
>>> fid.update(imgs_dist2, real=False)
>>> fid.compute()
tensor(12.6388)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
- val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
- ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> metric.update(imgs_dist1, real=True)
>>> metric.update(imgs_dist2, real=False)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.fid import FrechetInceptionDistance
>>> imgs_dist1 = lambda: torch.randint(0, 200, (100, 3, 299, 299), dtype=torch.uint8)
>>> imgs_dist2 = lambda: torch.randint(100, 255, (100, 3, 299, 299), dtype=torch.uint8)
>>> metric = FrechetInceptionDistance(feature=64)
>>> values = []
>>> for _ in range(3):
...     metric.update(imgs_dist1(), real=True)
...     metric.update(imgs_dist2(), real=False)
...     values.append(metric.compute())
...     metric.reset()
>>> fig_, ax_ = metric.plot(values)