CLIPScore

CLIPScore is a model-based metric for image captioning that correlates well with human judgments.

The benefit of CLIPScore is that it does not require reference captions for evaluation.

Here is a Python example that uses the CLIPScore metric to score a set of example captions against sample images:

12 import matplotlib.animation as animation
13 import matplotlib.pyplot as plt
14 import numpy as np
15 import torch
16 from matplotlib.table import Table
17 from skimage.data import astronaut, cat, coffee
18 from torchmetrics.multimodal import CLIPScore

Get sample images

# Sample images bundled with scikit-image, keyed by a short display name.
# The key order (astronaut, cat, coffee) is the order the images are scored in.
images = dict(
    astronaut=astronaut(),
    cat=cat(),
    coffee=coffee(),
)

Define candidate captions for the images

# One candidate caption per sample subject; each image is scored against every caption.
captions = [
    f"A photo of {subject}."
    for subject in ("an astronaut", "a cat", "a cup of coffee")
]

Define the models for CLIPScore

# CLIP checkpoints to compare; each name is passed to
# CLIPScore(model_name_or_path=...). Other supported checkpoints, such as
# "openai/clip-vit-base-patch32" or "openai/clip-vit-large-patch14-336",
# can be appended here to include them in the comparison.
models = ["openai/clip-vit-base-patch16", "openai/clip-vit-large-patch14"]

Collect scores for each image-caption pair

# Score every caption against every image, once per CLIP model.
# Each entry is {"scores": {caption: float}, "image": <key in images>, "model": <name>}.
score_results = []
for model in models:
    clip_score = CLIPScore(model_name_or_path=model)
    for key, img in images.items():
        # scikit-image returns (H, W, C) arrays, while CLIPScore documents a
        # single image as a (C, H, W) tensor. Move the channel axis to the front
        # explicitly instead of relying on the layout being inferred downstream.
        img_tensor = torch.from_numpy(np.asarray(img)).permute(2, 0, 1)
        # Store plain Python floats rather than 0-dim tensors so the results
        # are detached from the metric and can be formatted directly.
        caption_scores = {caption: float(clip_score(img_tensor, caption)) for caption in captions}
        score_results.append({"scores": caption_scores, "image": key, "model": model})

Create an animation to display the scores

# One figure with two side-by-side panels: the image on the left, its score table on the right.
fig, (ax_img, ax_table) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))
63
64
def update(num: int) -> tuple:
    """Draw animation frame ``num``.

    Shows the image of ``score_results[num]`` on the left and a table of its
    caption scores on the right, and titles the figure with the short name of
    the model that produced those scores.

    Args:
        num: Index into ``score_results`` supplied by ``FuncAnimation``.

    Returns:
        The two redrawn axes (``ax_img``, ``ax_table``).
    """
    results = score_results[num]
    scores, image, model = results["scores"], results["image"], results["model"]

    fig.suptitle(f"Model: {model.split('/')[-1]}", fontsize=16, fontweight="bold")

    # Start each frame from empty axes. Without this, every frame stacks a new
    # image and a new table on top of the previous ones, and the artists keep
    # accumulating for as long as the animation loops.
    ax_img.clear()
    ax_table.clear()

    # Image panel.
    ax_img.imshow(images[image])
    ax_img.axis("off")

    # Table panel: a bold header row followed by one row per caption.
    table = Table(ax_table, bbox=[0, 0, 1, 1])
    header1 = table.add_cell(0, 0, text="Caption", width=3, height=1)
    header2 = table.add_cell(0, 1, text="Score", width=1, height=1)
    header1.get_text().set_weight("bold")
    header2.get_text().set_weight("bold")
    for row, (caption, score) in enumerate(scores.items(), start=1):
        table.add_cell(row, 0, text=caption, width=3, height=1)
        # float() accepts both plain floats and 0-dim tensors.
        table.add_cell(row, 1, text=f"{float(score):.2f}", width=1, height=1)
    ax_table.add_table(table)
    ax_table.axis("off")
    return ax_img, ax_table
88
89
# Step through every (model, image) result, one frame each. Keep the returned
# object bound to a name so the animation is not garbage-collected.
ani = animation.FuncAnimation(
    fig,
    update,
    frames=len(score_results),
    interval=3000,  # delay between frames, in milliseconds
)

Total running time of the script: (0 minutes 43.883 seconds)

Gallery generated by Sphinx-Gallery