Note
Go to the end to download the full example code.
Perplexity
Perplexity is a measure of how well a probabilistic model predicts a sample.
In the context of language modeling, perplexity equals the exponential of the cross-entropy loss. A lower perplexity score indicates that the model is more certain about its predictions. Since Perplexity measures token probabilities, it is not suitable for evaluating decoding tasks like text generation or machine translation. Instead, it is commonly used to evaluate the logits of generative language models.
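As a quick, self-contained illustration of that relationship, here is a minimal sketch (toy logits and targets, not from any real model) showing that exponentiating the mean per-token cross-entropy yields the perplexity

import torch
import torch.nn.functional as F

# Toy example: 4 tokens over a vocabulary of 5 (hypothetical values)
toy_logits = torch.randn(4, 5)
toy_targets = torch.randint(0, 5, (4,))

cross_entropy = F.cross_entropy(toy_logits, toy_targets)  # mean over tokens
print(cross_entropy.exp())  # exponentiating the loss gives the perplexity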
Here’s a hypothetical Python example demonstrating the usage of Perplexity to evaluate a generative language model
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from torchmetrics.text import Perplexity
Load the GPT-2 model and tokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
Generate token logits for a sample text
sample_text = "The quick brown fox jumps over the lazy dog"
sample_input_ids = tokenizer.encode(sample_text, return_tensors="pt")

with torch.no_grad():
    sample_outputs = model(sample_input_ids, labels=sample_input_ids)
    logits = sample_outputs.logits
We can now calculate the perplexity of the logits
perplexity = Perplexity()
score = perplexity(preds=logits, target=sample_input_ids)
print(f"Perplexity, unshifted: {score.item()}")
Perplexity, unshifted: 1929.9822998046875
This perplexity score is suspiciously high. The cause is that the predictions and targets are misaligned: the logits at position i predict the token at position i + 1. We can fix this by removing the last position from the logits and the first token from the target
score = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])
print(f"Perplexity, shifted: {score.item()}")
Perplexity, shifted: 227.27783203125
Since perplexity is the exponential of the cross-entropy loss, we can verify the metric by comparing it with the loss returned by the model
metric_perplexity = score
loss_perplexity = sample_outputs.loss.exp()
print(torch.allclose(loss_perplexity, metric_perplexity))
True
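If you want to verify the shift by hand, a minimal sketch (reusing the logits and sample_input_ids from above) computes the same value with torch.nn.functional.cross_entropy

import torch.nn.functional as F

# Apply the same shift as above: the logits at position i predict the token at position i + 1
manual_loss = F.cross_entropy(
    logits[:, :-1].reshape(-1, logits.size(-1)),
    sample_input_ids[:, 1:].reshape(-1),
)
print(torch.allclose(manual_loss.exp(), score))  # expected to print True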
Be aware that sequences are often padded to ensure equal length. In such cases, the padding tokens should be ignored when calculating the perplexity. This can be achieved by specifying the ignore_index argument in the Perplexity metric
tokenizer.pad_token_id = tokenizer.eos_token_id
sample_input_ids = tokenizer.encode(sample_text, return_tensors="pt", padding="max_length", max_length=20)
with torch.no_grad():
    sample_outputs = model(sample_input_ids, labels=sample_input_ids)
    logits = sample_outputs.logits

perplexity = Perplexity(ignore_index=None)
score = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])
print(f"Perplexity, including padding: {score.item()}")

perplexity = Perplexity(ignore_index=tokenizer.pad_token_id)
score = perplexity(preds=logits[:, :-1], target=sample_input_ids[:, 1:])
print(f"Perplexity, ignoring padding: {score.item()}")
Perplexity, including padding: 24400.68359375
Perplexity, ignoring padding: 227.27783203125
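In practice, perplexity is usually accumulated over an entire evaluation set rather than a single sentence. A minimal sketch (eval_batches is a hypothetical iterable of padded input-id tensors) uses the metric's update/compute API to aggregate over batches

perplexity = Perplexity(ignore_index=tokenizer.pad_token_id)

for batch_input_ids in eval_batches:  # hypothetical dataloader of (batch, seq_len) input ids
    with torch.no_grad():
        outputs = model(batch_input_ids)
    # Accumulate shifted logits and targets; compute() aggregates over all batches
    perplexity.update(preds=outputs.logits[:, :-1], target=batch_input_ids[:, 1:])

print(f"Corpus perplexity: {perplexity.compute().item()}")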
Total running time of the script: (0 minutes 3.709 seconds)