Multi-Scale SSIM

Module Interface

class torchmetrics.image.MultiScaleStructuralSimilarityIndexMeasure(gaussian_kernel=True, kernel_size=11, sigma=1.5, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu', **kwargs)

Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.

This metric is a generalization of the Structural Similarity Index Measure (SSIM) that incorporates image details at several resolution scales.
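The combination described in [1] can be summarised as below. This is a sketch under our reading of the torchmetrics source rather than a verbatim transcription: M = len(betas), scale j is the input after j - 1 rounds of 2x2 average pooling, L is data_range, μ and σ are local statistics under the gaussian (or uniform) window, and each ratio is evaluated per window and then averaged over the image at that scale.

```latex
\begin{aligned}
C_1 &= (k_1 L)^2, \qquad C_2 = (k_2 L)^2 \\
\mathrm{cs}_j &= \frac{2\sigma_{xy}^{(j)} + C_2}{\big(\sigma_x^{(j)}\big)^2 + \big(\sigma_y^{(j)}\big)^2 + C_2} \\
\mathrm{SSIM}_M &= \frac{2\mu_x^{(M)}\mu_y^{(M)} + C_1}{\big(\mu_x^{(M)}\big)^2 + \big(\mu_y^{(M)}\big)^2 + C_1}\,\mathrm{cs}_M \\
\text{MS-SSIM} &= \big(\mathrm{SSIM}_M\big)^{\beta_M} \prod_{j=1}^{M-1} \big(\mathrm{cs}_j\big)^{\beta_j}
\end{aligned}
```

Because SSIM at the coarsest scale already contains the luminance term, only that scale contributes luminance, while every scale contributes a contrast-structure term.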

As input to forward and update the metric accepts the following input

  • preds (Tensor): Predictions from model

  • target (Tensor): Ground truth values

As output of forward and compute the metric returns the following output

  • msssim (Tensor): if reduction is 'elementwise_mean' or 'sum', a scalar float tensor with the mean or sum of the MS-SSIM values over the samples; if reduction is 'none' or None, a tensor of shape (N,) with one MS-SSIM value per sample

Parameters:
  • gaussian_kernel (bool) – If True (default), a gaussian kernel is used; if False, a uniform kernel is used

  • kernel_size (Union[int, Sequence[int]]) – size of the gaussian kernel

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce the metric score over samples.

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference of its two values and the input is clamped between them.

  • k1 (float) – Parameter of structural similarity index measure.

  • k2 (float) – Parameter of structural similarity index measure.

  • betas (tuple[float, ...]) – Exponent parameters for the individual similarities and contrast sensitivities returned at the different image resolutions. The number of scales equals len(betas).

  • normalize (Literal['relu', 'simple', None]) – When MultiScaleStructuralSimilarityIndexMeasure is used as a training loss, normalizing the per-scale values is desirable to improve training stability: 'relu' clips negative values to zero and 'simple' maps each value x to (x + 1) / 2. This normalize argument is not part of the original implementation [1]; it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
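To make the roles of betas and normalize concrete, here is a pure-Python sketch of the final combination step as we read it from the torchmetrics source: the contrast-structure values of the first M - 1 scales and the full SSIM at the coarsest scale are optionally normalized, raised to their betas and multiplied. The helper combine_scales and the input numbers are hypothetical; the real metric computes these values from images with tensors.

```python
import math
from typing import Optional, Sequence


def combine_scales(
    cs_per_scale: Sequence[float],
    ssim_last_scale: float,
    betas: Sequence[float],
    normalize: Optional[str] = "relu",
):
    """Hypothetical helper: weight per-scale values by betas and multiply them.

    cs_per_scale holds the contrast-structure value of scales 1..M-1 and
    ssim_last_scale is the full SSIM at the coarsest scale M (M == len(betas)).
    """
    if len(cs_per_scale) != len(betas) - 1:
        raise ValueError("expected len(betas) - 1 contrast-structure values")
    values = list(cs_per_scale) + [ssim_last_scale]
    if normalize == "relu":
        # Clip negatives to 0 so a fractional power never sees a negative base.
        values = [max(v, 0.0) for v in values]
    elif normalize == "simple":
        # Affine map from [-1, 1] onto [0, 1].
        values = [(v + 1.0) / 2.0 for v in values]
    elif normalize is not None:
        raise ValueError("normalize must be None, 'relu' or 'simple'")
    # Product of value_j ** beta_j over all scales. With normalize=None a
    # negative value and a fractional beta give a complex result in Python.
    return math.prod(v**b for v, b in zip(values, betas))


betas = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
print(combine_scales([0.99, 0.98, 0.97, 0.96], 0.95, betas))
print(combine_scales([0.99, 0.98, -0.1, 0.96], 0.95, betas, normalize="relu"))  # 0.0
```

With normalize=None, a negative per-scale value raised to a fractional beta has no real result (Python yields a complex number, a tensor power yields NaN), which is the instability the 'relu' and 'simple' options avoid.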

Returns:

Tensor with Multi-Scale SSIM score

Raises:
  • ValueError – If kernel_size is not an int or a Sequence of ints with size 2 or 3.

  • ValueError – If betas is not a tuple of floats.

  • ValueError – If normalize is neither None, 'relu' nor 'simple'.

Example

>>> from torch import rand
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> ms_ssim = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> ms_ssim(preds, target)
tensor(0.9628)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> from torch import rand
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> from torch import rand
>>> from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> metric = MultiScaleStructuralSimilarityIndexMeasure(data_range=1.0)
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.image.multiscale_structural_similarity_index_measure(preds, target, gaussian_kernel=True, sigma=1.5, kernel_size=11, reduction='elementwise_mean', data_range=None, k1=0.01, k2=0.03, betas=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333), normalize='relu')

Compute MultiScaleSSIM, Multi-scale Structural Similarity Index Measure.

This metric is a generalization of the Structural Similarity Index Measure (SSIM) that incorporates image details at several resolution scales.

Parameters:
  • preds (Tensor) – Predictions from model of shape [N, C, H, W]

  • target (Tensor) – Ground truth values of shape [N, C, H, W]

  • gaussian_kernel (bool) – If True, a gaussian kernel is used; if False, a uniform kernel is used

  • sigma (Union[float, Sequence[float]]) – Standard deviation of the gaussian kernel

  • kernel_size (Union[int, Sequence[int]]) – size of the gaussian kernel

  • reduction (Literal['elementwise_mean', 'sum', 'none', None]) –

    a method to reduce the metric score over samples.

    • 'elementwise_mean': takes the mean

    • 'sum': takes the sum

    • 'none' or None: no reduction will be applied

  • data_range (Union[float, tuple[float, float], None]) – the range of the data. If None, it is determined from the data (max - min). If a tuple is provided, the range is calculated as the difference of its two values and the input is clamped between them.

  • k1 (float) – Parameter of structural similarity index measure.

  • k2 (float) – Parameter of structural similarity index measure.

  • betas (tuple[float, ...]) – Exponent parameters for the individual similarities and contrast sensitivities returned at the different image resolutions. The number of scales equals len(betas).

  • normalize (Optional[Literal['relu', 'simple']]) – When MultiScaleSSIM is used as a training loss, normalizing the per-scale values is desirable to improve training stability: 'relu' clips negative values to zero and 'simple' maps each value x to (x + 1) / 2. This normalize argument is not part of the original implementation [1]; it is adapted from https://github.com/jorge-pessoa/pytorch-msssim instead.

Return type:

Tensor

Returns:

Tensor with Multi-Scale SSIM score

Raises:
  • TypeError – If preds and target don’t have the same data type.

  • ValueError – If preds and target don’t have BxCxHxW shape.

  • ValueError – If the length of kernel_size or sigma is not 2.

  • ValueError – If one of the elements of kernel_size is not an odd positive number.

  • ValueError – If one of the elements of sigma is not a positive number.

Example

>>> from torch import rand
>>> from torchmetrics.functional.image import multiscale_structural_similarity_index_measure
>>> preds = rand([3, 3, 256, 256])
>>> target = preds * 0.75
>>> multiscale_structural_similarity_index_measure(preds, target, data_range=1.0)
tensor(0.9628)

References

[1] Zhou Wang, Eero P. Simoncelli and Alan C. Bovik, "Multi-Scale Structural Similarity for Image Quality Assessment".