CLIPScore

The CLIPScore is a model-based image captioning metric that correlates well with human judgments.

The benefit of CLIPScore is that it does not require reference captions for evaluation.

Here’s a hypothetical Python example demonstrating the usage of the CLIPScore metric to evaluate image captions:

12 import matplotlib.animation as animation
13 import matplotlib.pyplot as plt
14 import numpy as np
15 import torch
16 from matplotlib.table import Table
17 from skimage.data import astronaut, cat, coffee
18
19 from torchmetrics.multimodal import CLIPScore

Get sample images

# Sample images keyed by a short, human-readable name.
images = {}
for name, loader in (("astronaut", astronaut), ("cat", cat), ("coffee", coffee)):
    images[name] = loader()

Define hypothetical captions for the images

# Candidate captions; each is intended to match exactly one sample image.
captions = [f"A photo of {subject}." for subject in ("an astronaut", "a cat", "a cup of coffee")]

Define the models for CLIPScore

# CLIP checkpoints to benchmark with CLIPScore.
models = [
    "openai/clip-vit-base-patch16",
    # Uncomment to also evaluate a long-context CLIP variant:
    # "zer0int/LongCLIP-L-Diffusers",
]

Collect scores for each image-caption pair

# Score every image/caption pair with each model and collect the results
# as one record per (model, image) combination.
score_results = []
for model in models:
    metric = CLIPScore(model_name_or_path=model)
    for name, img in images.items():
        # NOTE(review): skimage samples are HxWxC uint8 arrays; confirm this
        # layout matches what CLIPScore expects for raw image tensors.
        img_tensor = torch.tensor(np.array(img))
        caption_scores = {}
        for caption in captions:
            caption_scores[caption] = metric(img_tensor, caption)
        score_results.append({"scores": caption_scores, "image": name, "model": model})

Create an animation to display the scores

61 fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5))
62
63
def update(num: int) -> tuple:
    """Refresh the image panel and score table for frame *num*.

    Reads the record at ``score_results[num]`` and redraws both axes,
    returning them for the animation machinery.
    """
    entry = score_results[num]
    scores = entry["scores"]
    image_key = entry["image"]
    model_name = entry["model"]

    # Title shows which model produced this frame's scores.
    fig.suptitle(f"Model: {model_name.split('/')[-1]}", fontsize=16, fontweight="bold")

    # Left panel: the raw image, axes hidden.
    ax_img.imshow(images[image_key])
    ax_img.axis("off")

    # Right panel: a two-column caption/score table filling the axes.
    table = Table(ax_table, bbox=[0, 0, 1, 1])
    for col, text in enumerate(("Caption", "Score")):
        cell = table.add_cell(0, col, text=text, width=3 if col == 0 else 1, height=1)
        cell.get_text().set_weight("bold")
    for row, (caption, score) in enumerate(scores.items(), start=1):
        table.add_cell(row, 0, text=caption, width=3, height=1)
        table.add_cell(row, 1, text=f"{score:.2f}", width=1, height=1)
    ax_table.add_table(table)
    ax_table.axis("off")
    return ax_img, ax_table
87
88
89 ani = animation.FuncAnimation(fig, update, frames=len(score_results), interval=3000)

Total running time of the script: (0 minutes 11.606 seconds)

Gallery generated by Sphinx-Gallery

You are viewing an outdated version of TorchMetrics Docs

Click here to view the latest version→