ARNIQA

Module Interface

class torchmetrics.image.arniqa.ARNIQA(regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False, **kwargs)

ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.

ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with high correlation to human opinion. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions, regardless of image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
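The encoder/regressor structure described above can be sketched in plain PyTorch. This is a minimal illustration under stated assumptions, not the library's model: the class name TwoScaleQualityModel is hypothetical, a tiny convolutional encoder stands in for the pretrained ResNet-50, the half-scale image is assumed to be produced by bilinear interpolation, and a sigmoid is used only as a placeholder to map the regressor output into [0, 1] (the actual score scaling used by ARNIQA may differ).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoScaleQualityModel(nn.Module):
    """Hypothetical stand-in for an encoder + linear regressor over two scales.

    A tiny conv encoder replaces ResNet-50 so the sketch runs without
    pretrained weights; this is NOT the ARNIQA model shipped by the library.
    """

    def __init__(self, feat_dim: int = 16) -> None:
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, feat_dim, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(1),  # global pooling -> (N, feat_dim, 1, 1)
            nn.Flatten(),             # -> (N, feat_dim)
        )
        # Linear regressor over the concatenated full- and half-scale features.
        self.regressor = nn.Linear(2 * feat_dim, 1)

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        # Half-scale copy of the input (bilinear downsampling is an assumption).
        img_half = F.interpolate(img, scale_factor=0.5, mode="bilinear", align_corners=False)
        # Same encoder applied to both scales, features concatenated per image.
        feats = torch.cat([self.encoder(img), self.encoder(img_half)], dim=1)
        score = self.regressor(feats).squeeze(1)  # -> (N,)
        # Placeholder mapping to [0, 1]; the real score scaling may differ.
        return torch.sigmoid(score)


torch.manual_seed(0)
model = TwoScaleQualityModel()
scores = model(torch.rand(4, 3, 64, 64))
print(scores.shape)  # torch.Size([4])
```

Sharing one encoder across both scales keeps the feature spaces comparable, which is what makes a single linear regressor over the concatenated features meaningful.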

The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True, otherwise it should be normalized with the ImageNet mean and standard deviation.
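The two accepted input conventions differ only by the per-channel ImageNet normalization. The sketch below shows that transform explicitly, using the same mean and standard deviation values as the Normalize example further down this page; the helper name imagenet_normalize is hypothetical and not part of torchmetrics.

```python
import torch

# ImageNet channel statistics (same values as the Normalize example on this page).
IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)


def imagenet_normalize(img: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: map a (N, 3, H, W) batch in [0, 1] to ImageNet-normalized values."""
    if img.ndim != 4 or img.shape[1] != 3:
        raise ValueError(f"Expected shape (N, 3, H, W), got {tuple(img.shape)}")
    # .to(img) matches the statistics' dtype and device to the input batch.
    return (img - IMAGENET_MEAN.to(img)) / IMAGENET_STD.to(img)


img = torch.rand(8, 3, 224, 224)    # the form expected when normalize=True
img_norm = imagenet_normalize(img)  # the form expected when normalize=False
```

torchvision.transforms.Normalize computes the same per-channel (x - mean) / std, so either route produces input suitable for normalize=False.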

Note

Using this metric requires the torchvision package to be installed. Install it with either pip install torchmetrics[image] or pip install torchvision.

As input to forward and update, the metric accepts the following input:

  • img (Tensor): tensor with images of shape (N, 3, H, W)

As output of forward and compute, the metric returns the following output:

  • arniqa (Tensor): tensor with ARNIQA score. If reduction is set to none, the output will have shape (N,), otherwise it will be a scalar tensor. Tensor values are in the [0, 1] range, where higher is better.
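Because the module interface separates update from compute, per-image scores can be accumulated over several batches before the reduction is applied. The sketch below shows one common way to do this; the class RunningScoreReducer is hypothetical, it assumes per-image scores are already available, and it assumes the mean is taken over every image seen so far. It is not necessarily how the library stores its state.

```python
import torch


class RunningScoreReducer:
    """Hypothetical accumulator: collects per-image scores across update() calls."""

    def __init__(self, reduction: str = "mean") -> None:
        if reduction not in ("sum", "mean", "none"):
            raise ValueError(f"reduction must be 'sum', 'mean' or 'none', got {reduction!r}")
        self.reduction = reduction
        self.scores: list[torch.Tensor] = []

    def update(self, batch_scores: torch.Tensor) -> None:
        # batch_scores: per-image scores of shape (N,) for one batch
        self.scores.append(batch_scores.detach())

    def compute(self) -> torch.Tensor:
        all_scores = torch.cat(self.scores)  # shape (total number of images,)
        if self.reduction == "sum":
            return all_scores.sum()
        if self.reduction == "mean":
            return all_scores.mean()
        return all_scores  # "none": one score per image


reducer = RunningScoreReducer("mean")
reducer.update(torch.tensor([0.2, 0.4]))
reducer.update(torch.tensor([0.6]))
print(reducer.compute())  # tensor(0.4000)
```

Note that this averages over all three images (0.4) rather than averaging the two batch means (0.45), so batches of different sizes are weighted by their number of images.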

Parameters:
  • img – the input image, passed to forward or update (not a constructor argument)

  • regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.

  • reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].

  • normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.

  • autocast (bool) – if True, the metric converts the model to mixed precision before running the forward pass.

  • kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
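For readers unfamiliar with mixed precision, the sketch below illustrates the general idea behind the autocast option using PyTorch's own torch.autocast context manager. It is a generic illustration, not the metric's internals: the nn.Linear layer is a stand-in model, and CPU with bfloat16 is chosen only so the example runs without a GPU. How the metric applies mixed precision internally is not described on this page.

```python
import torch
import torch.nn as nn

# Stand-in model; under CPU autocast, linear layers run in bfloat16.
model = nn.Linear(8, 1)
x = torch.rand(4, 8)

with torch.no_grad(), torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    out_amp = model(x)   # reduced-precision forward pass

with torch.no_grad():
    out_fp32 = model(x)  # regular float32 forward pass

print(out_amp.dtype, out_fp32.dtype)  # torch.bfloat16 torch.float32
```

Mixed precision trades a small amount of numerical accuracy for speed and memory, so scores computed with autocast=True may differ slightly from full-precision ones.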

Raises:
  • ModuleNotFoundError – If torchvision package is not installed

  • ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"]

  • ValueError – If reduction is not in ["sum", "mean", "none"]

  • ValueError – If normalize is not a bool

  • ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].

  • ValueError – If the input image values are not in the [0, 1] range when normalize is set to True
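The documented ValueError conditions can be reproduced as a plain validation function, which is handy for checking inputs before calling the metric. This is a hypothetical re-creation named check_arniqa_args; the library's actual checks, their order, and their error messages may differ.

```python
import torch


def check_arniqa_args(
    img: torch.Tensor,
    regressor_dataset: str = "koniq10k",
    reduction: str = "mean",
    normalize: bool = True,
) -> None:
    """Hypothetical re-creation of the documented ValueError conditions."""
    if regressor_dataset not in ("kadid10k", "koniq10k"):
        raise ValueError(f"Invalid regressor_dataset: {regressor_dataset!r}")
    if reduction not in ("sum", "mean", "none"):
        raise ValueError(f"Invalid reduction: {reduction!r}")
    if not isinstance(normalize, bool):
        raise ValueError(f"normalize must be a bool, got {type(normalize).__name__}")
    # Shape check: batch of 3-channel images, i.e. (N, 3, H, W).
    if not isinstance(img, torch.Tensor) or img.ndim != 4 or img.shape[1] != 3:
        raise ValueError("Input must be a tensor of shape (N, 3, H, W)")
    # Range check applies only when the metric is expected to normalize the input.
    if normalize and (img.min() < 0 or img.max() > 1):
        raise ValueError("With normalize=True, input values must lie in [0, 1]")


check_arniqa_args(torch.rand(2, 3, 32, 32))  # valid input: returns None silently
```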

Examples

>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> img = rand(8, 3, 224, 224)
>>> # Input in the [0, 1] range (not ImageNet-normalized)
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=True)
>>> metric(img)
tensor(0.5308)
>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Input already normalized with the ImageNet mean and standard deviation
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=False)
>>> metric(img)
tensor(0.5065)

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, the method automatically calls metric.compute and plots that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> metric.update(torch.rand(8, 3, 224, 224))
>>> fig_, ax_ = metric.plot()
[Figure: arniqa-1.png, plot of a single ARNIQA value]

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(8, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
[Figure: arniqa-2.png, plot of multiple ARNIQA values]

Functional Interface

torchmetrics.functional.image.arniqa(img, regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False)

ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.

ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image with high correlation to human opinion. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions, regardless of image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.

The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True, otherwise it should be normalized with the ImageNet mean and standard deviation.

Note

Using this metric requires the torchvision package to be installed. Install it with either pip install torchmetrics[image] or pip install torchvision.

Parameters:
  • img (Tensor) – the input image

  • regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.

  • reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].

  • normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.

  • autocast (bool) – boolean indicating whether to use automatic mixed precision

Return type:

Tensor

Returns:

A tensor in the [0, 1] range, where higher is better, representing the ARNIQA score of the input image. If reduction is set to none, the output will have shape (N,), otherwise it will be a scalar tensor.

Raises:
  • ModuleNotFoundError – If torchvision package is not installed

  • ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"]

  • ValueError – If reduction is not in ["sum", "mean", "none"]

  • ValueError – If normalize is not a bool

  • ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].

  • ValueError – If the input image values are not in the [0, 1] range when normalize is set to True

Examples

>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> img = rand(8, 3, 224, 224)
>>> # Input in the [0, 1] range (not ImageNet-normalized)
>>> arniqa(img, regressor_dataset='koniq10k', normalize=True)
tensor(0.5308)
>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Input already normalized with the ImageNet mean and standard deviation
>>> arniqa(img, regressor_dataset='koniq10k', normalize=False)
tensor(0.5065)
