Deep Image Structure And Texture Similarity (DISTS)

Module Interface

class torchmetrics.image.dists.DeepImageStructureAndTextureSimilarity(reduction='mean', **kwargs)

Calculates Deep Image Structure and Texture Similarity (DISTS) score.

The metric is a full-reference image quality assessment (IQA) model that combines sensitivity to structural distortions (e.g., artifacts due to noise, blur, or compression) with tolerance to texture resampling (exchanging the content of a texture region with a new sample of the same texture). It is based on a convolutional neural network (CNN) that transforms the reference and distorted images into a new representation. Within this representation, a set of measurements is defined that suffices to capture the appearance of a variety of visual distortions. Despite its name, the score behaves as a distance: lower values indicate more similar images.
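To make the structure and texture measurements more concrete, here is a minimal sketch of a DISTS-style distance computed from precomputed CNN feature maps. It follows the formulation of the original DISTS paper: per-channel spatial means are compared by a texture term, per-channel variances and covariances by a structure term, and the distance is one minus their weighted sum. The helper name `dists_from_features`, the uniform weights, and the constants `c1` and `c2` are assumptions for illustration; the actual metric uses weights learned on a pretrained network, which this sketch does not reproduce.

```python
import torch


def dists_from_features(
    feats_x: list[torch.Tensor],
    feats_y: list[torch.Tensor],
    c1: float = 1e-6,
    c2: float = 1e-6,
) -> torch.Tensor:
    """Hypothetical helper: DISTS-style distance from lists of (N, C, H, W) feature maps.

    Assumption: uniform weights replace the learned per-channel weights of the real metric.
    Returns one distance per sample, shape (N,).
    """
    n_stages = len(feats_x)
    similarity = 0.0
    for fx, fy in zip(feats_x, feats_y):
        # Per-channel spatial statistics, each of shape (N, C)
        mu_x = fx.mean(dim=(2, 3))
        mu_y = fy.mean(dim=(2, 3))
        dx = fx - mu_x[..., None, None]
        dy = fy - mu_y[..., None, None]
        var_x = (dx**2).mean(dim=(2, 3))
        var_y = (dy**2).mean(dim=(2, 3))
        cov_xy = (dx * dy).mean(dim=(2, 3))
        # Texture term: compares channel means (equals 1 when they match)
        texture = (2 * mu_x * mu_y + c1) / (mu_x**2 + mu_y**2 + c1)
        # Structure term: compares channel (co)variances (equals 1 when they match)
        structure = (2 * cov_xy + c2) / (var_x + var_y + c2)
        # Uniform weights chosen so that all texture and structure weights sum to 1
        weight = 1.0 / (2 * fx.shape[1] * n_stages)
        similarity = similarity + weight * (texture + structure).sum(dim=1)
    # Distance: 0 for identical features, larger for more dissimilar ones
    return 1 - similarity
```

With identical inputs both terms equal 1 for every channel, so the sketch returns (approximately) 0, consistent with reading DISTS as a distance.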

As input to forward and update, the metric accepts the following input:

  • preds (Tensor): tensor with images of shape (N, 3, H, W)

  • target (Tensor): tensor with images of shape (N, 3, H, W)

As output of forward and compute, the metric returns the following output:

  • dists (Tensor): float scalar tensor with the DISTS score reduced over samples, averaged by default (reduction="mean") or summed (reduction="sum")
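Images loaded from disk often come as uint8 arrays in channel-last layout. The following minimal sketch converts such a batch to the (N, 3, H, W) float layout listed above. The helper `to_dists_input` is hypothetical, and the [0, 1] value range is an assumption based on the examples, which feed tensors from `torch.rand`.

```python
import torch


def to_dists_input(images: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: uint8 batch (N, H, W, 3) -> float batch (N, 3, H, W) in [0, 1].

    Assumption: the metric expects float images scaled to [0, 1].
    """
    if images.ndim != 4 or images.shape[-1] != 3:
        raise ValueError(f"Expected shape (N, H, W, 3), got {tuple(images.shape)}")
    # Move the channel axis to position 1, then rescale 8-bit values to [0, 1]
    return images.permute(0, 3, 1, 2).float() / 255.0
```

The converted preds and target batches must still share the same shape before being passed to the metric.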

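Parameters:
  • reduction (str) – Reduction applied over samples; one of "mean" or "sum".

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises: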

ValueError – If reduction is not one of [“mean”, “sum”]

Example

>>> from torch import rand
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> preds = rand(10, 3, 100, 100)
>>> target = rand(10, 3, 100, 100)
>>> metric(preds, target)
tensor(0.1882, grad_fn=<CloneBackward0>)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.dists.deep_image_structure_and_texture_similarity(preds, target, reduction=None)

Calculates Deep Image Structure and Texture Similarity (DISTS) score.

Parameters:
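  • preds (Tensor) – Predicted image tensor of shape (N, 3, H, W).

  • target (Tensor) – Target image tensor of shape (N, 3, H, W).

  • reduction (Optional[Literal['sum', 'mean', 'none']]) – Reduction applied over samples: 'mean' averages and 'sum' adds the per-sample scores, while None or 'none' returns one score per sample.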

Return type:

Tensor

Returns:

DISTS scores between the predicted and target images: a tensor of shape (N,) when no reduction is applied, otherwise a scalar tensor.
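The reduction options can be read as a final step over the per-sample scores. The sketch below is a hypothetical helper, `reduce_scores`, not the library's code; it mirrors the behavior shown in the example, where reduction=None returns one score per image and 'mean' returns a scalar, and it assumes 'none' behaves like None.

```python
from typing import Literal, Optional

import torch


def reduce_scores(
    scores: torch.Tensor,
    reduction: Optional[Literal["sum", "mean", "none"]] = None,
) -> torch.Tensor:
    """Hypothetical helper: apply a reduction to per-sample scores of shape (N,)."""
    if reduction == "mean":
        return scores.mean()
    if reduction == "sum":
        return scores.sum()
    if reduction is None or reduction == "none":
        # Keep one score per image pair
        return scores
    raise ValueError(f"Unsupported reduction: {reduction!r}")
```

Keeping the unreduced scores is useful when individual image pairs need to be inspected; 'mean' and 'sum' collapse the batch into a single value.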

Example

>>> from torch import rand
>>> from torchmetrics.functional.image.dists import deep_image_structure_and_texture_similarity
>>> preds = rand(5, 3, 256, 256)
>>> target = rand(5, 3, 256, 256)
>>> deep_image_structure_and_texture_similarity(preds, target)
tensor([0.1285, 0.1344, 0.1356, 0.1277, 0.1276], grad_fn=<RsubBackward1>)
>>> deep_image_structure_and_texture_similarity(preds, target, reduction='mean')
tensor(0.1308, grad_fn=<MeanBackward0>)