{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "# CLIPScore\n",
    "\n",
    "CLIPScore is a model-based image captioning metric that correlates well with human judgments.\n",
    "\n",
    "Its main benefit is that it is reference-free: it does not require reference captions for evaluation.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here is a hypothetical Python example that uses the CLIPScore metric to score several captions against several images:\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "import warnings\n",
    "\n",
    "import matplotlib.animation as animation\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "import torch\n",
    "from matplotlib.table import Table\n",
    "from skimage.data import astronaut, cat, coffee\n",
    "\n",
    "from torchmetrics.multimodal import CLIPScore"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Load sample images from `skimage.data`\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "images = {\n",
    "    \"astronaut\": astronaut(),\n",
    "    \"cat\": cat(),\n",
    "    \"coffee\": coffee(),\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define hypothetical captions for the images\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "captions = [\n",
    "    \"A photo of an astronaut.\",\n",
    "    \"A photo of a cat.\",\n",
    "    \"A photo of a cup of coffee.\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Define the CLIP models to compare; each entry is passed to `CLIPScore(model_name_or_path=...)`\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
    "models = [\n",
    "    \"openai/clip-vit-base-patch16\",\n",
    "    # Uncomment the line below to also score the images with this alternative checkpoint.\n",
    "    # \"zer0int/LongCLIP-L-Diffusers\",\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Score every caption against every image with each model\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false
   },
   "outputs": [],
   "source": [
"score_results = []\n\n\ndef process_model(model):\n \"\"\"Process a CLIP model by evaluating image-caption pairs and recording scores.\n\n Args:\n model: The name or path of the CLIP model to use for evaluation\n\n This function handles exceptions if the model fails to load or process,\n allowing the program to continue with other models.\n\n \"\"\"\n try:\n clip_score = CLIPScore(model_name_or_path=model)\n for key, img in images.items():\n img_tensor = torch.tensor(np.array(img))\n caption_scores = {caption: clip_score(img_tensor, caption) for caption in captions}\n score_results.append({\"scores\": caption_scores, \"image\": key, \"model\": model})\n except Exception as e:\n warnings.warn(f\"Error loading model {model} - skipping this test. Error details: {e}\", stacklevel=2)\n\n\nfor model in models:\n process_model(model)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Create an animation to display the scores\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "fig, (ax_img, ax_table) = plt.subplots(1, 2, figsize=(10, 5))\n\n\ndef update(num: int) -> tuple:\n \"\"\"Update the image and table with the scores for the given model.\"\"\"\n results = score_results[num]\n scores, image, model = results[\"scores\"], results[\"image\"], results[\"model\"]\n\n fig.suptitle(f\"Model: {model.split('/')[-1]}\", fontsize=16, fontweight=\"bold\")\n\n # Update image\n ax_img.imshow(images[image])\n ax_img.axis(\"off\")\n\n # Update table\n table = Table(ax_table, bbox=[0, 0, 1, 1])\n header1 = table.add_cell(0, 0, text=\"Caption\", width=3, height=1)\n header2 = table.add_cell(0, 1, text=\"Score\", width=1, height=1)\n header1.get_text().set_weight(\"bold\")\n header2.get_text().set_weight(\"bold\")\n for i, (caption, score) in enumerate(scores.items()):\n table.add_cell(i + 1, 0, text=caption, width=3, height=1)\n table.add_cell(i + 1, 1, text=f\"{score:.2f}\", width=1, height=1)\n 
ax_table.add_table(table)\n ax_table.axis(\"off\")\n return ax_img, ax_table\n\n\nani = animation.FuncAnimation(fig, update, frames=len(score_results), interval=3000)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.18" } }, "nbformat": 4, "nbformat_minor": 0 }