Perplexity¶
Module Interface¶
- class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)[source]¶
Perplexity measures how well a language model predicts a text sample.
It is calculated as the exponentiated average negative log-likelihood of the predicted tokens; equivalently, two raised to the average number of bits per token the model needs to encode the sample. Lower values indicate better predictions.
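For a sequence of N scored tokens this corresponds to the standard formulation (stated here for reference; tokens matching ignore_index are excluded from both the sum and N):

\mathrm{Perplexity} = \exp\left(-\frac{1}{N}\sum_{t=1}^{N} \log p(x_t \mid x_{<t})\right)

where p(x_t | x_{<t}) is the probability the model assigns to the ground-truth token x_t.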
As input to forward and update the metric accepts the following input:

- preds (Tensor): Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size]; this is typically the output of a language model. Scores will be normalized internally using softmax.
- target (Tensor): Ground truth token indices with shape [batch_size, seq_len].

As output of forward and compute the metric returns the following output:

- perp (Tensor): A tensor with the perplexity score.
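The typical stateful usage is to accumulate over batches with update and read the result with compute, as in this sketch (the shapes and batch count are illustrative, not prescribed by the API):

>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> for _ in range(3):                        # e.g. batches from a dataloader
...     preds = torch.rand(2, 8, 5)           # [batch_size, seq_len, vocab_size]
...     target = torch.randint(5, (2, 8))     # [batch_size, seq_len]
...     metric.update(preds, target)          # accumulate state, returns nothing
>>> score = metric.compute()                  # perplexity over all three batches
>>> metric.reset()                            # clear state before the next epoch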
- Parameters:
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
Examples
>>> from torch import rand, randint
>>> from torchmetrics.text import Perplexity
>>> preds = rand(2, 8, 5)
>>> target = randint(5, (2, 8))
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.8540)
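To make the definition concrete, the same value can be reproduced with plain PyTorch. This is a sketch of an equivalent computation, not the library's internal code; it continues the session above and assumes preds holds raw logits:

>>> import torch
>>> import torch.nn.functional as F
>>> # cross_entropy applies log-softmax and skips ignore_index targets,
>>> # so exp(mean NLL over the remaining tokens) equals the perplexity.
>>> nll = F.cross_entropy(preds.reshape(-1, preds.size(-1)),
...                       target.reshape(-1), ignore_index=-100)
>>> torch.exp(nll)
tensor(5.8540)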
- plot(val=None, ax=None)[source]¶
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
tuple[Figure, Axes]
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8))))
>>> fig_, ax_ = metric.plot(values)
Functional Interface¶
- torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)[source]¶
Perplexity measures how well a language model predicts a text sample.
This metric is calculated as the exponentiated average negative log-likelihood of the predicted tokens; equivalently, two raised to the average number of bits per token the model needs to encode the sample.
- Parameters:
preds (Tensor) – Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size]; this is typically the output of a language model. Scores will be normalized internally using softmax.
target (Tensor) – Ground truth token indices with shape [batch_size, seq_len].
ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
- Return type:
Tensor
- Returns:
Perplexity value
Examples
>>> from torch import rand, randint
>>> from torchmetrics.functional.text import perplexity
>>> preds = rand(2, 8, 5)
>>> target = randint(5, (2, 8))
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.8540)
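As a usage note, ignore_index is most often set to mask padding positions in a batch. A minimal sketch, where the pad id of 0 and the concrete token ids are illustrative assumptions rather than anything prescribed by torchmetrics:

>>> import torch
>>> from torchmetrics.functional.text import perplexity
>>> token_ids = torch.tensor([[3, 1, 4, 0, 0],
...                           [2, 2, 3, 1, 0]])           # assume 0 is the pad id
>>> target = token_ids.masked_fill(token_ids == 0, -100)  # mask out padding
>>> logits = torch.rand(2, 5, 5)                          # [batch_size, seq_len, vocab_size]
>>> score = perplexity(logits, target, ignore_index=-100)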