Spatial Distortion Index

Module Interface

class torchmetrics.image.SpatialDistortionIndex(norm_order=1, window_size=7, reduction='elementwise_mean', **kwargs)

Compute Spatial Distortion Index (SpatialDistortionIndex), also known as D_s.

The metric is used to compare the spatial distortion between two images. A value of 0 indicates no distortion (optimal value) and corresponds to the case where the high resolution panchromatic image is equal to the low resolution panchromatic image. The metric is defined as:

$\begin{split}D_s = \sqrt[q]{\frac{1}{L}\sum_{l=1}^L|Q(\hat{G_l}, P) - Q(\tilde{G_l}, \tilde{P})|^q}\end{split}$

where $$Q$$ is the universal image quality index (see UniversalImageQualityIndex for more info), $$\hat{G_l}$$ is the l-th band of the high resolution multispectral image, $$\tilde{G_l}$$ is the l-th band of the low resolution multispectral image, $$P$$ is the high resolution panchromatic image, $$\tilde{P}$$ is the low resolution panchromatic image, $$L$$ is the number of bands and $$q$$ is the order of the norm applied to the difference.
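The formula can be illustrated with a minimal pure-Python sketch. It assumes Q is computed once over all pixels of a band (the global form of the universal image quality index); the library's implementation may differ, for example by evaluating Q over local windows. The helper names `uqi` and `d_s` are hypothetical and are not part of torchmetrics.

```python
def uqi(x, y):
    """Global universal image quality index Q of two equal-length pixel sequences.

    Assumption: Q is evaluated once over the whole band, not per local window.
    """
    n = len(x)
    mean_x = sum(x) / n
    mean_y = sum(y) / n
    var_x = sum((a - mean_x) ** 2 for a in x) / (n - 1)
    var_y = sum((b - mean_y) ** 2 for b in y) / (n - 1)
    cov_xy = sum((a - mean_x) * (b - mean_y) for a, b in zip(x, y)) / (n - 1)
    return 4 * cov_xy * mean_x * mean_y / ((var_x + var_y) * (mean_x**2 + mean_y**2))


def d_s(q_high, q_low, norm_order=1):
    """Combine per-band Q values into D_s.

    q_high[l] stands for Q(G_hat_l, P) and q_low[l] for Q(G_tilde_l, P_tilde).
    """
    if len(q_high) != len(q_low):
        raise ValueError("Both lists must contain one Q value per band.")
    total = sum(abs(a - b) ** norm_order for a, b in zip(q_high, q_low))
    return (total / len(q_high)) ** (1 / norm_order)


# Identical signals give Q = 1 (up to floating-point rounding).
print(uqi([1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0]))

# Three bands with made-up Q values; with q = 1 this is the mean absolute difference.
print(d_s([0.9, 0.8, 0.7], [0.85, 0.9, 0.7], norm_order=1))  # approximately 0.05
```

When the Q values agree band by band, every difference is zero and D_s is 0, which matches the optimal value described above.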

As input to forward and update, the metric accepts the following input:

• preds (Tensor): High resolution multispectral image of shape (N,C,H,W).

• target (Dict): A dictionary containing the following keys:
    • ms (Tensor): Low resolution multispectral image of shape (N,C,H',W').

    • pan (Tensor): High resolution panchromatic image of shape (N,C,H,W).

    • pan_lr (Tensor): Low resolution panchromatic image of shape (N,C,H',W'). Optional (it defaults to None in the functional interface).

where H and W must be multiples of H' and W', respectively.
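Because H and W are integer multiples of H' and W', each low resolution pixel covers a whole block of high resolution pixels. The sketch below shows one simple way to derive a low resolution single-channel image from a high resolution one by averaging non-overlapping blocks. It is an illustration of the shape relationship only, not the library's degradation procedure, and `downsample_block_mean` is a hypothetical helper name.

```python
def downsample_block_mean(image, ratio):
    """Average non-overlapping ratio x ratio blocks of a 2D image (a list of rows).

    Illustrative only: assumes the same integer ratio along both axes.
    """
    height, width = len(image), len(image[0])
    if height % ratio or width % ratio:
        raise ValueError("H and W must be multiples of the ratio.")
    out = []
    for top in range(0, height, ratio):
        row = []
        for left in range(0, width, ratio):
            # Collect the pixels of one block and replace them with their mean.
            block = [image[top + i][left + j] for i in range(ratio) for j in range(ratio)]
            row.append(sum(block) / len(block))
        out.append(row)
    return out


pan = [
    [1.0, 3.0, 5.0, 7.0],
    [1.0, 3.0, 5.0, 7.0],
    [2.0, 2.0, 6.0, 6.0],
    [4.0, 4.0, 8.0, 8.0],
]
pan_lr = downsample_block_mean(pan, ratio=2)
print(pan_lr)  # [[2.0, 6.0], [3.0, 7.0]]
```

A 4x4 input with ratio 2 yields a 2x2 output, mirroring the (H,W) to (H',W') relationship required of pan and pan_lr.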

As output of forward and compute, the metric returns the following output:

• sdi (Tensor): If reduction != 'none', returns a float scalar tensor with the SDI value reduced over samples; otherwise returns a tensor of shape (N,) with one SDI value per sample.

Parameters:
• norm_order (int) – Order of the norm applied on the difference.

• window_size (int) – Window size of the filter applied to degrade the high resolution panchromatic image.

• reduction (Literal['elementwise_mean', 'sum', 'none']) –

Method used to reduce the metric score over samples.

• 'elementwise_mean': takes the mean (default)

• 'sum': takes the sum

• 'none': no reduction will be applied

• kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
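The three reduction modes can be sketched on a plain list of per-sample scores. This is a minimal illustration of the behaviour described above, not the library's code; `reduce_scores` is a hypothetical helper and the scores are made-up values.

```python
def reduce_scores(scores, reduction="elementwise_mean"):
    """Reduce a list of per-sample scores in the way the reduction parameter describes."""
    if reduction == "elementwise_mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction == "none":
        return list(scores)  # one value per sample, unchanged
    raise ValueError(f"Unknown reduction: {reduction!r}")


per_sample = [0.02, 0.04, 0.06]  # hypothetical SDI values for three samples
print(reduce_scores(per_sample, "elementwise_mean"))  # approximately 0.04
print(reduce_scores(per_sample, "sum"))               # approximately 0.12
print(reduce_scores(per_sample, "none"))              # [0.02, 0.04, 0.06]
```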

Example

>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> sdi = SpatialDistortionIndex()
>>> sdi(preds, target)
tensor(0.0090)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

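Parameters:
• val – Either a single result from calling metric.forward or metric.compute, or a list of such results. If no value is provided, metric.compute is called and its result is plotted.

• ax – A matplotlib axis object. If provided, the plot is added to that axis.

Returns:

Figure and Axes object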

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> _ = torch.manual_seed(42)
>>> from torchmetrics.image import SpatialDistortionIndex
>>> preds = torch.rand([16, 3, 32, 32])
>>> target = {
...     'ms': torch.rand([16, 3, 16, 16]),
...     'pan': torch.rand([16, 3, 32, 32]),
... }
>>> metric = SpatialDistortionIndex()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)


Functional Interface

torchmetrics.functional.image.spatial_distortion_index(preds, ms, pan, pan_lr=None, norm_order=1, window_size=7, reduction='elementwise_mean')

Calculate Spatial Distortion Index (SpatialDistortionIndex) also known as D_s.

The metric is used to compare the spatial distortion between two images.

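Parameters:
• preds (Tensor) – High resolution multispectral image.

• ms (Tensor) – Low resolution multispectral image.

• pan (Tensor) – High resolution panchromatic image.

• pan_lr (Optional[Tensor]) – Low resolution panchromatic image.

• norm_order (int) – Order of the norm applied on the difference.

• window_size (int) – Window size of the filter applied to degrade the high resolution panchromatic image.

• reduction (Literal['elementwise_mean', 'sum', 'none']) – Method used to reduce the metric score: 'elementwise_mean' takes the mean (default), 'sum' takes the sum, 'none' applies no reduction.

Return type:

Tensor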

Returns:

Tensor with SpatialDistortionIndex score

Raises:
• TypeError – If preds, ms, pan and pan_lr don’t have the same data type.

• ValueError – If preds, ms, pan and pan_lr don’t have BxCxHxW shape.

• ValueError – If preds, ms, pan and pan_lr don’t have the same batch and channel sizes.

• ValueError – If preds and pan don’t have the same dimension.

• ValueError – If ms and pan_lr don’t have the same dimension.

• ValueError – If the dimensions of preds and pan are not multiples of those of ms.

• ValueError – If norm_order is not a positive integer.

• ValueError – If window_size is not a positive integer.
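The shape-related conditions listed above can be checked before calling the function. The sketch below works on plain shape tuples and mirrors the listed rules. It is an approximation for illustration, not the library's validation code, and `check_shapes` and `check_positive_int` are hypothetical helper names.

```python
def check_positive_int(name, value):
    """Raise ValueError unless value is a positive integer (bool is rejected)."""
    if isinstance(value, bool) or not isinstance(value, int) or value <= 0:
        raise ValueError(f"Expected `{name}` to be a positive integer, got {value!r}.")


def check_shapes(preds, ms, pan, pan_lr=None):
    """Validate (B, C, H, W) shape tuples against the rules listed above."""
    shapes = {"preds": tuple(preds), "ms": tuple(ms), "pan": tuple(pan)}
    if pan_lr is not None:
        shapes["pan_lr"] = tuple(pan_lr)
    for name, shape in shapes.items():
        if len(shape) != 4:
            raise ValueError(f"`{name}` must have BxCxHxW shape, got {shape}.")
    # All inputs must share the same batch and channel sizes.
    if len({shape[:2] for shape in shapes.values()}) != 1:
        raise ValueError("Inputs must have the same batch and channel sizes.")
    if shapes["preds"][2:] != shapes["pan"][2:]:
        raise ValueError("`preds` and `pan` must have the same height and width.")
    if pan_lr is not None and shapes["ms"][2:] != shapes["pan_lr"][2:]:
        raise ValueError("`ms` and `pan_lr` must have the same height and width.")
    if shapes["preds"][2] % shapes["ms"][2] or shapes["preds"][3] % shapes["ms"][3]:
        raise ValueError("Height and width of `preds` must be multiples of those of `ms`.")


# Shapes taken from the documented example: these pass silently.
check_shapes((16, 3, 32, 32), (16, 3, 16, 16), (16, 3, 32, 32))
check_positive_int("norm_order", 1)
check_positive_int("window_size", 7)
```

Checking shapes up front gives an early, readable error when, for instance, preds is 30x30 while ms is 16x16.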

Example

>>> import torch
>>> from torchmetrics.functional.image import spatial_distortion_index
>>> _ = torch.manual_seed(42)
>>> preds = torch.rand([16, 3, 32, 32])
>>> ms = torch.rand([16, 3, 16, 16])
>>> pan = torch.rand([16, 3, 32, 32])
>>> spatial_distortion_index(preds, ms, pan)
tensor(0.0090)