{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# ROUGE\n\nThe ROUGE (Recall-Oriented Understudy for Gisting Evaluation) metric used to evaluate the quality of generated text compared to a reference text. It does so by computing the overlap between two texts, for which a subsequent precision and recall value can be computed. The ROUGE score is often used in the context of generative tasks such as text summarization and machine translation.\n\nA major difference with Perplexity comes from the fact that ROUGE evaluates actual text, whereas Perplexity evaluates logits.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here's a hypothetical Python example demonstrating the usage of unigram ROUGE F-score to evaluate a generative language model:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "from torchmetrics.text import ROUGEScore\nfrom transformers import AutoTokenizer, pipeline\n\npipe = pipeline(\"text-generation\", model=\"openai-community/gpt2\")\ntokenizer = AutoTokenizer.from_pretrained(\"openai-community/gpt2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Define the prompt and target texts\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "prompt = \"The quick brown fox\"\ntarget_text = \"The quick brown fox jumps over the lazy dog.\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate a sample text using the GPT-2 model\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sample_text = pipe(prompt, max_length=20, do_sample=True, temperature=0.1, pad_token_id=tokenizer.eos_token_id)[0][\n \"generated_text\"\n]\nprint(sample_text)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calculate the ROUGE of the generated text\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "rouge = ROUGEScore()\nrouge(preds=[sample_text], target=[target_text])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "By default, the ROUGE score is calculated using a whitespace tokenizer. You can also calculate the ROUGE for the tokens directly:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "token_rouge = ROUGEScore(tokenizer=lambda text: tokenizer.tokenize(text))\ntoken_rouge(preds=[sample_text], target=[target_text])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since ROUGE is a text-based metric, it can be used to benchmark decoding strategies. For example, you can compare temperature settings:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import matplotlib.pyplot as plt # noqa: E402\n\ntemperatures = [x * 0.1 for x in range(1, 10)] # Generate temperature values from 0 to 1 with a step of 0.1\nn_samples = 100 # Note that a real benchmark typically requires more data\n\naverage_scores = []\n\nfor temperature in temperatures:\n sample_text = pipe(\n prompt, max_length=20, do_sample=True, temperature=temperature, pad_token_id=tokenizer.eos_token_id\n )[0][\"generated_text\"]\n scores = [rouge(preds=[sample_text], target=[target_text])[\"rouge1_fmeasure\"] for _ in range(n_samples)]\n average_scores.append(sum(scores) / n_samples)\n\n# Plot the average ROUGE score for each temperature\nplt.plot(temperatures, average_scores)\nplt.xlabel(\"Generation temperature\")\nplt.ylabel(\"Average unigram ROUGE F-Score\")\nplt.title(\"ROUGE for varying temperature settings\")\nplt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }