CLIPScore

The CLIPScore is a model-based image captioning metric that correlates well with human judgments.

The benefit of CLIPScore is that it does not require reference captions for evaluation.

Here’s a hypothetical Python example demonstrating the usage of the CLIPScore metric to evaluate image captions:

12 import matplotlib.animation as animation
13 import matplotlib.pyplot as plt
14 import numpy as np
15 import torch
16 from matplotlib.table import Table
17 from skimage.data import astronaut, cat, coffee
18
19 from torchmetrics.multimodal import CLIPScore

Get sample images

# Fetch the bundled scikit-image sample pictures; each loader's function
# name doubles as the display label for its image.
images = {loader.__name__: loader() for loader in (astronaut, cat, coffee)}

Define hypothetical captions for the images

# Candidate captions — one intended match per sample image.
captions = [
    f"A photo of {subject}."
    for subject in ("an astronaut", "a cat", "a cup of coffee")
]

Define the models for CLIPScore

# CLIP checkpoints to compare.  Two further variants can be added here
# if desired:
#   "openai/clip-vit-base-patch32"
#   "openai/clip-vit-large-patch14-336"
models = [
    "openai/clip-vit-base-patch16",
    "openai/clip-vit-large-patch14",
]

Collect scores for each image-caption pair

# Score every caption against every image with each CLIP checkpoint.
# Each entry records the per-caption scores plus the image/model labels.
score_results = []
for checkpoint in models:
    metric = CLIPScore(model_name_or_path=checkpoint)
    for label, img in images.items():
        # NOTE(review): image is passed as an HWC uint8 tensor, matching the
        # upstream torchmetrics gallery example.
        pixels = torch.tensor(np.array(img))
        scores = {}
        for caption in captions:
            scores[caption] = metric(pixels, caption)
        score_results.append({"scores": scores, "image": label, "model": checkpoint})

Create an animation to display the scores

# One row, two panels: the image on the left, the score table on the right.
fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5))
def update(num: int) -> tuple:
    """Refresh the figure for frame *num*: show the image on the left and a
    caption/score table on the right, titled with the model name."""
    entry = score_results[num]

    # Title carries only the checkpoint name, not the "openai/" prefix.
    fig.suptitle(f"Model: {entry['model'].split('/')[-1]}", fontsize=16, fontweight="bold")

    # Left panel: the raw image, with axis ticks hidden.
    ax_img.imshow(images[entry["image"]])
    ax_img.axis("off")

    # Right panel: a two-column table of captions and their CLIP scores.
    table = Table(ax_table, bbox=[0, 0, 1, 1])
    for col, title in enumerate(("Caption", "Score")):
        cell = table.add_cell(0, col, text=title, width=3 if col == 0 else 1, height=1)
        cell.get_text().set_weight("bold")
    for row, (caption, score) in enumerate(entry["scores"].items(), start=1):
        table.add_cell(row, 0, text=caption, width=3, height=1)
        table.add_cell(row, 1, text=f"{score:.2f}", width=1, height=1)
    ax_table.add_table(table)
    ax_table.axis("off")
    return ax_img, ax_table
# Cycle through every (model, image) result, holding each frame for 3 s.
num_frames = len(score_results)
ani = animation.FuncAnimation(fig, update, frames=num_frames, interval=3000)

Total running time of the script: (0 minutes 44.891 seconds)

Gallery generated by Sphinx-Gallery