ARNIQA
Module Interface
- class torchmetrics.image.arniqa.ARNIQA(regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False, **kwargs)
ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image, with high correlation with human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions regardless of the image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True; otherwise it should be normalized with the ImageNet mean and standard deviation.
Note
Using this metric requires the torchvision package to be installed. Install it with either pip install torchmetrics[image] or pip install torchvision.
As input to forward and update, the metric accepts the following input:
img (Tensor): tensor with images of shape (N, 3, H, W)
As output of forward and compute, the metric returns the following output:
arniqa (Tensor): tensor with the ARNIQA score. If reduction is set to none, the output will have shape (N,); otherwise it will be a scalar tensor. Tensor values are in the [0, 1] range, where higher is better.
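The two input conventions above differ only by a per-channel shift and scale. The following dependency-free sketch shows that arithmetic on a single RGB pixel, using the ImageNet statistics that appear in the Normalize example on this page. The helper name normalize_pixel is hypothetical and is not part of torchmetrics.

```python
# ImageNet channel statistics, as used in the Normalize example on this page.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Map one RGB pixel from the [0, 1] range to ImageNet-normalized values.

    Hypothetical helper for illustration only; it is not part of torchmetrics.
    """
    if len(rgb) != 3:
        raise ValueError("expected exactly three channel values (R, G, B)")
    # Subtract the channel mean, then divide by the channel standard deviation.
    return tuple(
        (value - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )


# A pixel equal to the channel means maps to zero in every channel.
centered = normalize_pixel(IMAGENET_MEAN)

# Pure white lands above 1.0 in every channel, so normalized inputs are
# no longer confined to the [0, 1] range.
white = normalize_pixel((1.0, 1.0, 1.0))
```

Because normalized values routinely fall outside [0, 1], an already-normalized image should be passed with normalize set to False, as described above.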
- Parameters:
regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.
reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].
normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.
autocast (bool) – if True, the metric will convert the model to mixed precision before running the forward pass.
kwargs (Any) – additional keyword arguments, see Advanced metric settings for more info.
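As a rough illustration of the three reduction options, the sketch below applies them to a plain list of per-image scores. It assumes that sum and mean are taken over the batch of per-image scores, as the parameter description suggests. The helper reduce_scores and the sample values are hypothetical and only stand in for the tensor logic.

```python
def reduce_scores(scores, reduction="mean"):
    """Reduce per-image scores over the batch dimension.

    Hypothetical helper: `scores` is a plain list of floats standing in for
    a per-image score tensor of shape (N,).
    """
    if reduction == "none":
        # Keep one score per image.
        return list(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction == "mean":
        return sum(scores) / len(scores)
    raise ValueError(
        f"reduction must be one of ['sum', 'mean', 'none'], got {reduction!r}"
    )


per_image = [0.25, 0.5, 0.75]  # made-up scores for a batch of three images
kept = reduce_scores(per_image, "none")
total = reduce_scores(per_image, "sum")
average = reduce_scores(per_image, "mean")
```

Under that assumption, a summed score grows with the batch size, so mean or none is likely the more convenient choice when scores from batches of different sizes need to be compared.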
- Raises:
ModuleNotFoundError – If the torchvision package is not installed.
ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"].
ValueError – If reduction is not in ["sum", "mean", "none"].
ValueError – If normalize is not a bool.
ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].
ValueError – If the input image values are not in the [0, 1] range when normalize is set to True.
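The argument checks listed above can be mirrored in a few lines of plain Python, which may be useful for validating configuration before a metric is constructed. This is a minimal sketch: check_arniqa_args and is_rejected are hypothetical helpers, and the error messages are illustrative rather than the library's own.

```python
ALLOWED_DATASETS = ("kadid10k", "koniq10k")
ALLOWED_REDUCTIONS = ("sum", "mean", "none")


def check_arniqa_args(regressor_dataset="koniq10k", reduction="mean", normalize=True):
    """Raise ValueError for argument values the documentation lists as invalid.

    Hypothetical helper; it does not call into torchmetrics.
    """
    if regressor_dataset not in ALLOWED_DATASETS:
        raise ValueError(
            f"regressor_dataset must be one of {list(ALLOWED_DATASETS)}, "
            f"got {regressor_dataset!r}"
        )
    if reduction not in ALLOWED_REDUCTIONS:
        raise ValueError(
            f"reduction must be one of {list(ALLOWED_REDUCTIONS)}, got {reduction!r}"
        )
    if not isinstance(normalize, bool):
        raise ValueError(f"normalize must be a bool, got {type(normalize).__name__}")


def is_rejected(**kwargs):
    """Return True if the given arguments fail validation."""
    try:
        check_arniqa_args(**kwargs)
    except ValueError:
        return True
    return False
```

For example, is_rejected(normalize=1) returns True, because the integer 1 is not a bool even though it is truthy.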
Examples
>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> img = rand(8, 3, 224, 224)
>>> # Non-normalized input
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=True)
>>> metric(img)
tensor(0.5308)

>>> from torch import rand
>>> from torchmetrics.image.arniqa import ARNIQA
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Normalized input
>>> metric = ARNIQA(regressor_dataset='koniq10k', normalize=False)
>>> metric(img)
tensor(0.5065)
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> metric.update(torch.rand(8, 3, 224, 224))
>>> fig_, ax_ = metric.plot()

>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.image.arniqa import ARNIQA
>>> metric = ARNIQA(regressor_dataset='koniq10k')
>>> values = []
>>> for _ in range(3):
...     values.append(metric(torch.rand(8, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.image.arniqa(img, regressor_dataset='koniq10k', reduction='mean', normalize=True, autocast=False)
ARNIQA: leArning distoRtion maNifold for Image Quality Assessment metric.
ARNIQA is a No-Reference Image Quality Assessment metric that predicts the technical quality of an image, with high correlation with human opinions. ARNIQA consists of an encoder and a regressor. The encoder is a ResNet-50 model trained in a self-supervised way to model the image distortion manifold, so that it generates similar representations for images with similar distortions regardless of the image content. The regressor is a linear model trained on IQA datasets using the ground-truth quality scores. ARNIQA extracts features from the full- and half-scale versions of the input image and then outputs a quality score in the [0, 1] range, where higher is better.
The input image is expected to have shape (N, 3, H, W). The image should be in the [0, 1] range if normalize is set to True; otherwise it should be normalized with the ImageNet mean and standard deviation.
Note
Using this metric requires the torchvision package to be installed. Install it with either pip install torchmetrics[image] or pip install torchvision.
- Parameters:
img (Tensor) – the input image
regressor_dataset (Literal['kadid10k', 'koniq10k']) – dataset used for training the regressor. Choose between [koniq10k, kadid10k]. koniq10k corresponds to the KonIQ-10k dataset, which consists of real-world images with authentic distortions. kadid10k corresponds to the KADID-10k dataset, which consists of images with synthetically generated distortions.
reduction (Literal['sum', 'mean', 'none']) – indicates how to reduce over the batch dimension. Choose between [sum, mean, none].
normalize (bool) – by default this is True, meaning that the input is expected to be in the [0, 1] range. If set to False, the input is instead expected to be already normalized with the ImageNet mean and standard deviation.
autocast (bool) – boolean indicating whether to use automatic mixed precision
- Return type:
- Returns:
A tensor in the [0, 1] range, where higher is better, representing the ARNIQA score of the input image. If reduction is set to none, the output will have shape (N,); otherwise it will be a scalar tensor.
- Raises:
ModuleNotFoundError – If the torchvision package is not installed.
ValueError – If regressor_dataset is not in ["kadid10k", "koniq10k"].
ValueError – If reduction is not in ["sum", "mean", "none"].
ValueError – If normalize is not a bool.
ValueError – If the input image is not a valid image tensor with shape [N, 3, H, W].
ValueError – If the input image values are not in the [0, 1] range when normalize is set to True.
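The two input-related errors above depend on the batch shape and, when normalize is True, on the value range. The sketch below expresses those conditions on a shape tuple and the extreme pixel values, so it runs without any tensor library. check_image_batch and accepts are hypothetical helpers, and the real implementation may check more than this.

```python
def check_image_batch(shape, min_value, max_value, normalize=True):
    """Validate the shape and value range of an image batch.

    Hypothetical helper: `shape` is a tuple such as `tuple(img.shape)`, and
    `min_value` / `max_value` are the smallest and largest pixel values.
    """
    # The batch must be four-dimensional with three channels: (N, 3, H, W).
    if len(shape) != 4 or shape[1] != 3:
        raise ValueError(f"expected shape (N, 3, H, W), got {tuple(shape)}")
    # The range check only applies to inputs that are not yet normalized.
    if normalize and (min_value < 0.0 or max_value > 1.0):
        raise ValueError("expected values in the [0, 1] range when normalize=True")


def accepts(shape, min_value, max_value, normalize=True):
    """Return True if the described batch passes validation."""
    try:
        check_image_batch(shape, min_value, max_value, normalize)
    except ValueError:
        return False
    return True
```

With a real tensor, the three arguments could be obtained as tuple(img.shape), img.min().item() and img.max().item() before calling the functional interface.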
Examples
>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> img = rand(8, 3, 224, 224)
>>> # Non-normalized input
>>> arniqa(img, regressor_dataset='koniq10k', normalize=True)
tensor(0.5308)

>>> from torch import rand
>>> from torchmetrics.functional.image.arniqa import arniqa
>>> from torchvision.transforms import Normalize
>>> img = rand(8, 3, 224, 224)
>>> img = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)
>>> # Normalized input
>>> arniqa(img, regressor_dataset='koniq10k', normalize=False)
tensor(0.5065)