{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "\n# Perplexity\n\nPerplexity is a measure of how well a probabilistic model predicts a sample.\n\nIn the context of language modeling, perplexity is the exponential of the average cross-entropy loss over the predicted tokens. A lower perplexity means that the model assigns higher probability to the observed tokens, i.e. it is more certain about its predictions.\nBecause perplexity is computed from the probabilities a model assigns to the tokens of a given reference sequence, it is not suitable for evaluating the decoded output of tasks such as text generation or machine translation. Instead, it is commonly used to evaluate the logits of generative language models.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here is a hypothetical Python example demonstrating how to use the `Perplexity` metric from torchmetrics to evaluate a generative language model:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "import torch\nfrom torchmetrics.text import Perplexity\nfrom transformers import AutoModelForCausalLM, AutoTokenizer" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load the GPT-2 model and tokenizer:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "model = AutoModelForCausalLM.from_pretrained(\"gpt2\")\ntokenizer = AutoTokenizer.from_pretrained(\"gpt2\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Generate token logits for a sample text:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "sample_text = \"The quick brown fox jumps over the lazy dog\"\nsample_input_ids = tokenizer.encode(sample_text, return_tensors=\"pt\")  # shape: (1, seq_len)\n\nwith torch.no_grad():\n    # Passing labels also makes the model return its own cross-entropy loss, used later for verification\n    sample_outputs = model(sample_input_ids, labels=sample_input_ids)\nlogits = sample_outputs.logits  # shape: (1, seq_len, vocab_size)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We can now calculate the perplexity of the logits:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { 
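"collapsed": false }, "outputs": [], "source": [ "# A minimal hand-rolled sketch for intuition (plain PyTorch, not the torchmetrics implementation;\n# the variable names below are illustrative): perplexity = exp(mean negative log-likelihood\n# of each target token under the logits). Like the torchmetrics cell that follows, it pairs\n# logits and targets position by position (no shift yet) and ignores no tokens.\nlog_probs = torch.log_softmax(logits, dim=-1)  # (1, seq_len, vocab_size)\n# Pick out the log-probability assigned to each target token\ntarget_log_probs = log_probs.gather(-1, sample_input_ids.unsqueeze(-1)).squeeze(-1)  # (1, seq_len)\nmanual_score = torch.exp(-target_log_probs.mean())\nprint(f\"Perplexity, unshifted (by hand): {manual_score.item()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The `Perplexity` metric from torchmetrics computes the same quantity, so the two values should agree up to floating-point error:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {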
"collapsed": false }, "outputs": [], "source": [ "# preds: (batch, seq_len, vocab_size) logits; target: (batch, seq_len) token IDs\nperplexity = Perplexity()\nscore = perplexity(preds=logits, target=sample_input_ids)\nprint(f\"Perplexity, unshifted: {score.item()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "This perplexity score is suspiciously high. The cause is an off-by-one misalignment: the logits at position `i` are the model's prediction for token `i + 1`, so the targets must be shifted one position to the left relative to the logits. We can fix this by removing the last position from the logits (it predicts a token beyond the end of the sequence) and the first token from the target (no logit predicts it):\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# Align the prediction at position i with the token at position i + 1\nscore = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])\nprint(f\"Perplexity, shifted: {score.item()}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since perplexity is the exponential of the cross-entropy loss, we can verify the calculation by comparing it with the exponential of the loss returned by the model, which applies the same one-position shift internally:\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# The model's loss is the mean cross-entropy over the shifted tokens, so exp(loss) should match the shifted score\nexpected = sample_outputs.loss.exp()\nprint(torch.allclose(score, expected))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Be aware that sequences are often padded to ensure equal length. In such cases, the padding tokens should be ignored when calculating the perplexity. 
This can be achieved by passing the padding token ID as the `ignore_index` argument of the `Perplexity` metric. GPT-2 has no dedicated padding token, so the example below reuses the end-of-sequence token for padding; as a consequence, any genuine end-of-sequence tokens in the targets would be ignored as well.\n\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": false }, "outputs": [], "source": [ "# GPT-2 has no padding token, so reuse the end-of-sequence token\ntokenizer.pad_token = tokenizer.eos_token\n# Right-pad the sample to 20 tokens\nsample_input_ids = tokenizer.encode(sample_text, return_tensors=\"pt\", padding=\"max_length\", max_length=20)\nwith torch.no_grad():\n    sample_outputs = model(sample_input_ids, labels=sample_input_ids)\nlogits = sample_outputs.logits\n\n# Default behavior: padding tokens count as regular targets\nperplexity = Perplexity(ignore_index=None)\nscore = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])\nprint(f\"Perplexity, including padding: {score.item()}\")\n\n# Exclude every target equal to the padding token ID\nperplexity = Perplexity(ignore_index=tokenizer.pad_token_id)\nscore = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])\nprint(f\"Perplexity, ignoring padding: {score.item()}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.19" } }, "nbformat": 4, "nbformat_minor": 0 }