Deep Image Structure And Texture Similarity (DISTS)
Module Interface
- class torchmetrics.image.dists.DeepImageStructureAndTextureSimilarity(reduction='mean', **kwargs)
Calculates Deep Image Structure and Texture Similarity (DISTS) score.
The metric is a full-reference image quality assessment (IQA) model that combines sensitivity to structural distortions (e.g., artifacts due to noise, blur, or compression) with tolerance to texture resampling (exchanging the content of a texture region with a new sample of the same texture). It is based on a convolutional neural network (CNN) that transforms the reference and distorted images into a new representation. Within this representation, it defines a set of measurements sufficient to capture the appearance of a variety of visual distortions.
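The following is a minimal pure-Python sketch of how such measurements can be combined into a single DISTS-style distance. It assumes the CNN feature maps have already been extracted and flattened into one list of floats per channel; the function names, the per-channel weights alpha and beta, and the stabilizing constants c1 and c2 are illustrative assumptions, not torchmetrics internals (the real metric uses a pretrained network and learned weights). Each channel contributes a texture term built from global means and a structure term built from variances and covariance, and the distance is 1 minus their normalized weighted sum.

```python
from statistics import fmean


def texture_term(x, y, c1=1e-6):
    # Compares global means of one feature channel (texture statistics).
    # c1 is an assumed small constant that avoids division by zero.
    mx, my = fmean(x), fmean(y)
    return (2 * mx * my + c1) / (mx ** 2 + my ** 2 + c1)


def structure_term(x, y, c2=1e-6):
    # Compares variances and covariance of one feature channel (structure).
    # c2 is an assumed small constant that avoids division by zero.
    mx, my = fmean(x), fmean(y)
    var_x = fmean([(a - mx) ** 2 for a in x])
    var_y = fmean([(b - my) ** 2 for b in y])
    cov_xy = fmean([(a - mx) * (b - my) for a, b in zip(x, y)])
    return (2 * cov_xy + c2) / (var_x + var_y + c2)


def dists_from_features(feats_x, feats_y, alpha, beta):
    """Hypothetical DISTS-style distance from precomputed feature maps.

    feats_x, feats_y: one flat list of activations per channel (assumed given).
    alpha, beta: non-negative per-channel weights (illustrative, not learned here).
    """
    total = sum(alpha) + sum(beta)  # normalize so that all weights sum to 1
    similarity = 0.0
    for x, y, a, b in zip(feats_x, feats_y, alpha, beta):
        similarity += (a / total) * texture_term(x, y) + (b / total) * structure_term(x, y)
    return 1.0 - similarity


ref = [[0.2, 0.8, 0.5, 0.1], [1.0, 0.0, 0.3, 0.7]]
dist = [[0.25, 0.7, 0.55, 0.1], [0.9, 0.2, 0.3, 0.6]]
weights = [1.0, 1.0]
print(dists_from_features(ref, ref, weights, weights))   # approximately 0.0
print(dists_from_features(ref, dist, weights, weights))  # about 0.024
```

Because both terms equal 1 when the features are identical, the distance is 0 for identical inputs and grows as the structure diverges. Two different inputs whose channel means match still get texture terms of 1, which is one way to see the tolerance to texture resampling described above.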
As input to forward and update the metric accepts the following input:
- preds (Tensor): tensor with predicted (distorted) images of shape (N, 3, H, W)
- target (Tensor): tensor with reference images of shape (N, 3, H, W)
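A minimal sketch of the input contract above, written against plain shape tuples so it runs without PyTorch (with real tensors you would pass preds.shape and target.shape, since torch.Size is a tuple subclass). The helper check_dists_inputs is hypothetical, and requiring identical shapes for preds and target is an assumption based on both being documented as (N, 3, H, W); torchmetrics' own validation may differ in details and error messages.

```python
def check_dists_inputs(preds_shape, target_shape):
    """Hypothetical check of the documented (N, 3, H, W) input contract.

    Returns N, the number of image pairs that would be scored.
    """
    # Both inputs are documented as 4-D tensors with 3 (RGB) channels.
    if len(preds_shape) != 4 or preds_shape[1] != 3:
        raise ValueError(f"preds must have shape (N, 3, H, W), got {tuple(preds_shape)}")
    # Assumption: full-reference comparison needs matching N, H and W.
    if tuple(preds_shape) != tuple(target_shape):
        raise ValueError(
            f"preds and target shapes differ: {tuple(preds_shape)} vs {tuple(target_shape)}"
        )
    return preds_shape[0]


print(check_dists_inputs((10, 3, 100, 100), (10, 3, 100, 100)))  # 10
```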
As output of forward and compute the metric returns the following output:
- dists (Tensor): float scalar tensor with the DISTS value over samples, averaged when reduction='mean' and summed when reduction='sum'
- Parameters:
reduction (Optional[Literal['mean', 'sum']]) – specifies the reduction to apply to the output.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
ValueError – If reduction is not one of [“mean”, “sum”]
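The documented reduction behaviour and the ValueError can be sketched with a small hypothetical accumulator, assuming per-sample DISTS scores have already been computed for each batch. RunningDistsReduction is an illustrative name, not a torchmetrics class; it keeps a running sum and sample count across update() calls and applies 'mean' or 'sum' only in compute().

```python
class RunningDistsReduction:
    """Hypothetical accumulator illustrating the 'mean'/'sum' reduction."""

    def __init__(self, reduction="mean"):
        # Mirror the documented check: only "mean" and "sum" are accepted.
        if reduction not in ("mean", "sum"):
            raise ValueError(f"reduction must be one of ('mean', 'sum'), got {reduction!r}")
        self.reduction = reduction
        self.score_sum = 0.0  # running sum of per-sample scores
        self.count = 0        # number of samples seen so far

    def update(self, per_sample_scores):
        # One score per image pair in the batch (assumed precomputed).
        self.score_sum += sum(per_sample_scores)
        self.count += len(per_sample_scores)

    def compute(self):
        if self.reduction == "sum":
            return self.score_sum
        return self.score_sum / self.count


acc = RunningDistsReduction("mean")
acc.update([0.25, 0.5])  # first batch: two image pairs
acc.update([0.75])       # second batch: one image pair
print(acc.compute())     # 0.5, the mean over all three samples
```

Keeping a sum and a count, rather than averaging each batch, makes 'mean' the average over all samples: here the batch means are 0.375 and 0.75, whose average (0.5625) differs from the per-sample mean of 0.5.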
Example
>>> from torch import rand
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> preds = rand(10, 3, 100, 100)
>>> target = rand(10, 3, 100, 100)
>>> metric(preds, target)
tensor(0.1882, grad_fn=<CloneBackward0>)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> metric.update(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.dists import DeepImageStructureAndTextureSimilarity
>>> metric = DeepImageStructureAndTextureSimilarity()
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(10, 3, 100, 100), torch.rand(10, 3, 100, 100)))
>>> fig_, ax_ = metric.plot(values)
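The val handling described for plot can be summarized with a small pure-Python sketch. values_to_plot and its compute argument are hypothetical names; the sketch only normalizes the input into a list of numbers (one per plotted point), assuming each result is a scalar (float() also accepts a one-element tensor), and leaves the actual matplotlib drawing out, so it is not how torchmetrics implements plot.

```python
def values_to_plot(val=None, compute=None):
    """Hypothetical helper: turn the documented `val` argument into plot points.

    - val is None            -> call compute() and plot that single result
    - val is a list or tuple -> one point per result (e.g. per forward() call)
    - anything else          -> treated as a single scalar result
    """
    if val is None:
        if compute is None:
            raise ValueError("either val or a compute callable is required")
        val = compute()
    if isinstance(val, (list, tuple)):
        return [float(v) for v in val]
    return [float(val)]


print(values_to_plot(0.19))                  # [0.19]
print(values_to_plot([0.18, 0.19, 0.20]))    # [0.18, 0.19, 0.2]
print(values_to_plot(compute=lambda: 0.25))  # [0.25]
```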
Functional Interface
- torchmetrics.functional.image.dists.deep_image_structure_and_texture_similarity(preds, target, reduction=None)
Calculates Deep Image Structure and Texture Similarity (DISTS) score.
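- Parameters:
preds (Tensor) – predicted (distorted) images of shape (N, 3, H, W).
target (Tensor) – reference images of shape (N, 3, H, W).
reduction – specifies the reduction to apply to the output. With the default None, one score per image pair is returned; with 'mean', the scores are averaged into a scalar (see the example below).
- Return type:
Tensor
- Returns:
DISTS score between the two images; lower values indicate more similar images.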
Example
>>> from torch import rand
>>> from torchmetrics.functional.image.dists import deep_image_structure_and_texture_similarity
>>> preds = rand(5, 3, 256, 256)
>>> target = rand(5, 3, 256, 256)
>>> deep_image_structure_and_texture_similarity(preds, target)
tensor([0.1285, 0.1344, 0.1356, 0.1277, 0.1276], grad_fn=<RsubBackward1>)
>>> deep_image_structure_and_texture_similarity(preds, target, reduction='mean')
tensor(0.1308, grad_fn=<MeanBackward0>)
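Because reduction='mean' averages the per-sample scores, the two outputs above can be cross-checked by hand. This quick arithmetic sketch uses the rounded values printed in the example; the inputs are random, so your own numbers will differ.

```python
# Rounded per-sample scores printed by the reduction=None call above.
per_sample = [0.1285, 0.1344, 0.1356, 0.1277, 0.1276]

# reduction='mean' averages the per-sample scores into one scalar.
mean_score = sum(per_sample) / len(per_sample)
print(round(mean_score, 4))  # 0.1308, matching the reduction='mean' output
```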