CLIP Image Quality Assessment (CLIP-IQA)
Module Interface
- class torchmetrics.multimodal.CLIPImageQualityAssessment(model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',), **kwargs)
Calculates CLIP-IQA, which can be used to measure the visual content of images.
The metric is based on the CLIP model, which is a neural network trained on a variety of (image, text) pairs to be able to generate a vector representation of the image and the text that is similar if the image and text are semantically similar.
The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between the image embeddings and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than the second prompt.
- Built-in prompts are:
quality: “Good photo.” vs “Bad photo.”
brightness: “Bright photo.” vs “Dark photo.”
noisiness: “Clean photo.” vs “Noisy photo.”
colorfullness: “Colorful photo.” vs “Dull photo.”
sharpness: “Sharp photo.” vs “Blurry photo.”
contrast: “High contrast photo.” vs “Low contrast photo.”
complexity: “Complex photo.” vs “Simple photo.”
natural: “Natural photo.” vs “Synthetic photo.”
happy: “Happy photo.” vs “Sad photo.”
scary: “Scary photo.” vs “Peaceful photo.”
new: “New photo.” vs “Old photo.”
warm: “Warm photo.” vs “Cold photo.”
real: “Real photo.” vs “Abstract photo.”
beautiful: “Beautiful photo.” vs “Ugly photo.”
lonely: “Lonely photo.” vs “Sociable photo.”
relaxing: “Relaxing photo.” vs “Stressful photo.”
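The scoring step described above (cosine similarity against a prompt pair, turned into a probability) can be sketched in plain Python. This is a minimal illustration and not the library's implementation: the function names are hypothetical, the two-dimensional embeddings are made up, and the `logit_scale` of 100 is an assumption about the temperature applied before the softmax.

```python
import math


def cosine_similarity(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def positive_prompt_probability(image_emb, positive_emb, negative_emb, logit_scale=100.0):
    """Probability that the image is closer to the positive prompt.

    logit_scale is an assumed temperature, not a documented parameter.
    """
    pos_logit = logit_scale * cosine_similarity(image_emb, positive_emb)
    neg_logit = logit_scale * cosine_similarity(image_emb, negative_emb)
    # A softmax over two logits reduces to a sigmoid of their difference.
    return 1.0 / (1.0 + math.exp(neg_logit - pos_logit))


# Toy embeddings: the image points mostly along the "positive" direction.
score = positive_prompt_probability([0.9, 0.1], [1.0, 0.0], [0.0, 1.0])
```

An image embedding that is equally similar to both prompts scores exactly 0.5 under this sketch, which is why values near 0.5 are best read as "undecided" rather than "average".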
As input to forward and update the metric accepts the following input:
images (Tensor): tensor with images fed to the feature extractor, with shape (N,C,H,W)
As output of forward and compute the metric returns the following output:
clip_iqa (Tensor or dict of tensors): tensor with the CLIP-IQA score. If a single prompt is provided, a single tensor with shape (N,) is returned. If a list of prompts is provided, a dict of tensors is returned with the prompt as key and the tensor with shape (N,) as value.
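Because the return type depends on how many prompts were requested, downstream code may want a uniform shape. The helper below is a hypothetical convenience, not part of torchmetrics: `as_score_dict` and its `single_key` argument are names made up for this sketch, and plain lists stand in for tensors.

```python
def as_score_dict(result, single_key="quality"):
    """Return metric output as a {prompt_name: scores} mapping.

    `result` is either a dict of per-prompt scores (multiple prompts)
    or a bare score container (single prompt). `single_key` is the
    label to use in the single-prompt case.
    """
    if isinstance(result, dict):
        return dict(result)  # shallow copy, keys are the prompt names
    return {single_key: result}


# Lists stand in for tensors of shape (N,).
single = as_score_dict([0.88, 0.89])
multi = as_score_dict({"quality": [0.87, 0.87], "brightness": [0.57, 0.48]})
```

With this wrapper, logging or plotting code can always iterate over `items()` regardless of the number of prompts.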
- Parameters:
model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. Available models are:
“clip_iqa”, model corresponding to the CLIP-IQA paper.
“openai/clip-vit-base-patch16”
“openai/clip-vit-base-patch32”
“openai/clip-vit-large-patch14-336”
“openai/clip-vit-large-patch14”
data_range (float) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.
prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Otherwise the input is expected to be a tuple, where each element is either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2, where the first string is the positive prompt and the second string is the negative prompt.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
Note
If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[image].
- Raises:
ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0
ValueError – If prompts is a tuple and it is not of length 2
ValueError – If prompts is a string and it is not one of the available prompts
ValueError – If prompts is a list of strings and not all strings are one of the available prompts
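The prompt rules and the errors listed above can be sketched as a small resolver. This is an illustration under stated assumptions, not the library's code: `resolve_prompts` and `BUILT_IN_PROMPTS` are hypothetical names, only two of the built-in pairs are included, and the `user_defined_<i>` naming is inferred from the key that appears in the custom-prompt example of this section.

```python
# Subset of the built-in prompt pairs listed in this section.
BUILT_IN_PROMPTS = {
    "quality": ("Good photo.", "Bad photo."),
    "brightness": ("Bright photo.", "Dark photo."),
}


def resolve_prompts(prompts):
    """Map a prompts argument to (names, (positive, negative) pairs)."""
    if isinstance(prompts, str):
        prompts = (prompts,)
    names, pairs = [], []
    custom_count = 0  # assumed: only custom pairs advance this counter
    for prompt in prompts:
        if isinstance(prompt, str):
            if prompt not in BUILT_IN_PROMPTS:
                raise ValueError(f"Unknown built-in prompt: {prompt!r}")
            names.append(prompt)
            pairs.append(BUILT_IN_PROMPTS[prompt])
        elif isinstance(prompt, tuple) and len(prompt) == 2:
            names.append(f"user_defined_{custom_count}")
            pairs.append(prompt)
            custom_count += 1
        else:
            raise ValueError("Custom prompts must be a (positive, negative) tuple of length 2")
    return names, pairs


names, pairs = resolve_prompts((("Super good photo.", "Super bad photo."), "brightness"))
```

Here `names` holds the dictionary keys a caller would expect in the output, and `pairs` holds the text that would be embedded for each key.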
- Example::
Single prompt:
>>> from torch import randint
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment()
>>> metric(imgs)
tensor([0.8894, 0.8902])
- Example::
Multiple prompts:
>>> from torch import randint
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=("quality", "brightness"))
>>> metric(imgs)
{'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])}
- Example::
Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.
>>> from torch import randint
>>> from torchmetrics.multimodal import CLIPImageQualityAssessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> metric = CLIPImageQualityAssessment(prompts=(("Super good photo.", "Super bad photo."), "brightness"))
>>> metric(imgs)
{'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])}
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> metric.update(torch.rand(1, 3, 224, 224))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.multimodal.clip_iqa import CLIPImageQualityAssessment
>>> metric = CLIPImageQualityAssessment()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(1, 3, 224, 224)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.multimodal.clip_image_quality_assessment(images, model_name_or_path='clip_iqa', data_range=1.0, prompts=('quality',))
Calculates CLIP-IQA, which can be used to measure the visual content of images.
The metric is based on the CLIP model, which is a neural network trained on a variety of (image, text) pairs to be able to generate a vector representation of the image and the text that is similar if the image and text are semantically similar.
The metric works by calculating the cosine similarity between user-provided images and pre-defined prompts. The prompts always come in pairs of “positive” and “negative”, such as “Good photo.” and “Bad photo.”. By calculating the similarity between the image embeddings and both the “positive” and “negative” prompt, the metric can determine which prompt the image is more similar to. The metric then returns the probability that the image is more similar to the first prompt than the second prompt.
- Built-in prompts are:
quality: “Good photo.” vs “Bad photo.”
brightness: “Bright photo.” vs “Dark photo.”
noisiness: “Clean photo.” vs “Noisy photo.”
colorfullness: “Colorful photo.” vs “Dull photo.”
sharpness: “Sharp photo.” vs “Blurry photo.”
contrast: “High contrast photo.” vs “Low contrast photo.”
complexity: “Complex photo.” vs “Simple photo.”
natural: “Natural photo.” vs “Synthetic photo.”
happy: “Happy photo.” vs “Sad photo.”
scary: “Scary photo.” vs “Peaceful photo.”
new: “New photo.” vs “Old photo.”
warm: “Warm photo.” vs “Cold photo.”
real: “Real photo.” vs “Abstract photo.”
beautiful: “Beautiful photo.” vs “Ugly photo.”
lonely: “Lonely photo.” vs “Sociable photo.”
relaxing: “Relaxing photo.” vs “Stressful photo.”
- Parameters:
images (Tensor) – Either a single [N, C, H, W] tensor or a list of [C, H, W] tensors
model_name_or_path (Literal['clip_iqa', 'openai/clip-vit-base-patch16', 'openai/clip-vit-base-patch32', 'openai/clip-vit-large-patch14-336', 'openai/clip-vit-large-patch14']) – string indicating the version of the CLIP model to use. By default this argument is set to clip_iqa, which corresponds to the model used in the original paper. Other available models are “openai/clip-vit-base-patch16”, “openai/clip-vit-base-patch32”, “openai/clip-vit-large-patch14-336” and “openai/clip-vit-large-patch14”
data_range (float) – The maximum value of the input tensor. For example, if the input images are in range [0, 255], data_range should be 255. The images are normalized by this value.
prompts (Tuple[Union[str, Tuple[str, str]]]) – A string, tuple of strings or nested tuple of strings. If a single string is provided, it must be one of the available prompts (see above). Otherwise the input is expected to be a tuple, where each element is either a string or a tuple of strings. If a string is provided, it must be one of the available prompts (see above). If a tuple is provided, it must be of length 2, where the first string is the positive prompt and the second string is the negative prompt.
Note
If using the default clip_iqa model, the package piq must be installed. Either install with pip install piq or pip install torchmetrics[multimodal].
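The effect of `data_range` described in the parameter list can be illustrated on a flat list of pixel values. This is a minimal sketch of "normalized by this value" read as a plain division; `normalize_pixels` is a hypothetical helper, and the real function operates on tensors rather than lists.

```python
def normalize_pixels(pixels, data_range=1.0):
    """Scale raw pixel values into [0, 1] by dividing by data_range."""
    return [p / data_range for p in pixels]


# 8-bit style values need data_range=255 to land in [0, 1].
scaled = normalize_pixels([0.0, 127.5, 255.0], data_range=255.0)
# With the default data_range, values are passed through unchanged.
unscaled = normalize_pixels([0.0, 0.5, 1.0])
```

Under this reading, passing images in [0, 255] while leaving `data_range` at its default of 1.0 would hand unscaled values to the model, so it is worth setting the argument to match the input range.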
- Return type:
- Returns:
A tensor of shape (N,) if a single prompt is provided. If a list of prompts is provided, a dictionary with the prompts as keys and tensors of shape (N,) as values.
- Raises:
ModuleNotFoundError – If transformers package is not installed or version is lower than 4.10.0
ValueError – If not all images have format [C, H, W]
ValueError – If prompts is a tuple and it is not of length 2
ValueError – If prompts is a string and it is not one of the available prompts
ValueError – If prompts is a list of strings and not all strings are one of the available prompts
- Example::
Single prompt:
>>> from torch import randint
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality",))
tensor([0.8894, 0.8902])
- Example::
Multiple prompts:
>>> from torch import randint
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=("quality", "brightness"))
{'quality': tensor([0.8693, 0.8705]), 'brightness': tensor([0.5722, 0.4762])}
- Example::
Custom prompts. Must always be a tuple of length 2, with a positive and negative prompt.
>>> from torch import randint
>>> from torchmetrics.functional.multimodal import clip_image_quality_assessment
>>> imgs = randint(255, (2, 3, 224, 224)).float()
>>> clip_image_quality_assessment(imgs, prompts=(("Super good photo.", "Super bad photo."), "brightness"))
{'user_defined_0': tensor([0.9578, 0.9654]), 'brightness': tensor([0.5495, 0.5764])}