Perplexity

Module Interface

class torchmetrics.text.perplexity.Perplexity(ignore_index=None, **kwargs)

Perplexity measures how well a language model predicts a text sample.

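It is computed as the exponential of the average negative log-likelihood per token. Equivalently, its base-2 logarithm is the average number of bits per token the model needs to encode the sample, so lower values are better.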
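Written out, for N scored tokens with targets y_i, the definition above reads as follows. The notation is ours, not the library's: p_θ(y_i) stands for the softmax probability the model assigns to the target token at position i given its context.

```latex
\mathrm{PPL} = \exp\!\left(-\frac{1}{N}\sum_{i=1}^{N} \ln p_\theta(y_i)\right),
\qquad
\log_2 \mathrm{PPL} = -\frac{1}{N}\sum_{i=1}^{N} \log_2 p_\theta(y_i)
```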

As input to forward and update the metric accepts the following input:

  • preds (Tensor): Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size], as output by a language model. Scores are normalized internally using softmax.

  • target (Tensor): Ground-truth token indices with shape [batch_size, seq_len]
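Raw logits are acceptable for preds because softmax maps any row of scores to a probability distribution, and adding the same constant to every score in a row leaves that distribution unchanged. The following is a minimal plain-Python sketch (no torch). `softmax` here is our own helper, not a torchmetrics function, and the logit values are hypothetical.

```python
import math

def softmax(scores):
    """Turn one row of unnormalized scores (logits) into probabilities."""
    # Subtract the max before exponentiating for numerical stability;
    # the result is unchanged because softmax is shift-invariant.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# One sequence position with vocab_size = 3 (hypothetical values).
logits = [2.0, 1.0, 0.1]
probs = softmax(logits)
shifted = softmax([s + 10.0 for s in logits])  # same scores plus a constant

print([round(p, 4) for p in probs])
print(all(abs(a - b) < 1e-12 for a, b in zip(probs, shifted)))
```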

As output of forward and compute the metric returns the following output:

  • perp (Tensor): A tensor with the perplexity score
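The following sanity check shows the scale of the returned value. If every logit in a row is equal, softmax assigns probability 1/vocab_size to each token, so the perplexity equals vocab_size regardless of the targets. A plain-Python sketch of that arithmetic follows; `uniform_perplexity` is a hypothetical helper.

```python
import math

def uniform_perplexity(vocab_size, num_tokens):
    """Perplexity of a model that spreads probability evenly over the vocabulary."""
    p = 1.0 / vocab_size                      # softmax of equal logits
    token_nll = [-math.log(p)] * num_tokens   # identical NLL at every position
    return math.exp(sum(token_nll) / num_tokens)

print(uniform_perplexity(vocab_size=5, num_tokens=16))
```

This gives a reference point: a model that does better than uniform guessing scores below vocab_size.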

Parameters:
  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.

  • kwargs (dict[str, Any]) – Additional keyword arguments, see Advanced metric settings for more info.
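The effect of ignore_index can be illustrated on per-token negative log-likelihoods. Positions whose target equals ignore_index are dropped from both the sum and the token count. The sketch below is plain Python under that assumption. `masked_perplexity` is a hypothetical helper, and the NLL values are made up.

```python
import math

def masked_perplexity(token_nll, targets, ignore_index=None):
    """Perplexity from per-token negative log-likelihoods (natural log).

    Positions whose target equals ignore_index are excluded from both the
    sum and the count, so they do not affect the result.
    """
    kept = [
        nll for nll, t in zip(token_nll, targets)
        if ignore_index is None or t != ignore_index
    ]
    if not kept:
        raise ValueError("no tokens left after applying ignore_index")
    return math.exp(sum(kept) / len(kept))

# Hypothetical values: the last two positions are padding marked with -100.
nll = [1.2, 0.8, 1.0, 5.0, 7.0]
targets = [3, 1, 4, -100, -100]
print(masked_perplexity(nll, targets, ignore_index=-100))  # padding excluded
print(masked_perplexity(nll, targets))                      # padding included
```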

Examples

>>> from torch import rand, randint
>>> from torchmetrics.text import Perplexity
>>> preds = rand(2, 8, 5)
>>> target = randint(5, (2, 8))
>>> target[0, 6:] = -100
>>> perp = Perplexity(ignore_index=-100)
>>> perp(preds, target)
tensor(5.8540)
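When the module is updated over several batches, the value returned by compute is, as we understand it, token-weighted. A running sum of negative log-likelihoods and a running token count are accumulated across update calls, and the exponential is taken only at the end. This is not the same as averaging per-batch perplexities. The plain-Python sketch below shows the difference under that assumption. `RunningPerplexity` is hypothetical and only mirrors the idea, not the torchmetrics code.

```python
import math

class RunningPerplexity:
    """Hypothetical accumulator: sum of token NLLs and token count across batches."""

    def __init__(self):
        self.total_nll = 0.0
        self.count = 0

    def update(self, token_nll):
        # token_nll: per-token NLLs of one batch, ignored positions already removed
        self.total_nll += sum(token_nll)
        self.count += len(token_nll)

    def compute(self):
        return math.exp(self.total_nll / self.count)

batches = [[1.0, 1.0, 1.0, 1.0], [3.0]]  # hypothetical batches of different sizes
acc = RunningPerplexity()
for b in batches:
    acc.update(b)

token_weighted = acc.compute()  # exp(7 / 5)
mean_of_batches = sum(math.exp(sum(b) / len(b)) for b in batches) / len(batches)
print(token_weighted, mean_of_batches)
```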

plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> metric.update(torch.rand(2, 8, 5), torch.randint(5, (2, 8)))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.text import Perplexity
>>> metric = Perplexity()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(2, 8, 5), torch.randint(5, (2, 8))))
>>> fig_, ax_ = metric.plot(values)

Functional Interface

torchmetrics.functional.text.perplexity.perplexity(preds, target, ignore_index=None)

Perplexity measures how well a language model predicts a text sample.

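This metric is computed as the exponential of the average negative log-likelihood per token. Equivalently, its base-2 logarithm is the average number of bits per token the model needs to encode the sample, so lower values are better.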

Parameters:
  • preds (Tensor) – Logits or unnormalized scores assigned to each token in a sequence, with shape [batch_size, seq_len, vocab_size], as output by a language model. Scores are normalized internally using softmax.

  • target (Tensor) – Ground-truth token indices with shape [batch_size, seq_len].

  • ignore_index (Optional[int]) – Integer specifying a target class to ignore. If given, this class index does not contribute to the returned score.
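To check the computation by hand, the following plain-Python sketch mirrors the documented arguments of perplexity (preds, target, ignore_index) on nested lists instead of tensors. `reference_perplexity` is a hypothetical name, and this is not the torchmetrics implementation. It assumes logits are normalized with a log-softmax and that ignored positions are excluded from the token count.

```python
import math

def reference_perplexity(preds, target, ignore_index=None):
    """Plain-Python perplexity sketch for nested lists.

    preds:  [batch_size][seq_len][vocab_size] logits
    target: [batch_size][seq_len] token indices
    """
    total_nll = 0.0
    count = 0
    for seq_logits, seq_target in zip(preds, target):
        for logits, t in zip(seq_logits, seq_target):
            if ignore_index is not None and t == ignore_index:
                continue  # masked position contributes nothing
            # log-softmax denominator via the log-sum-exp trick
            m = max(logits)
            log_z = m + math.log(sum(math.exp(s - m) for s in logits))
            total_nll += log_z - logits[t]  # equals -log p(t)
            count += 1
    return math.exp(total_nll / count)

# batch_size=1, seq_len=2, vocab_size=3; the second position is ignored
preds = [[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]]
target = [[1, -100]]
print(reference_perplexity(preds, target, ignore_index=-100))
```

With the second position masked, only the uniform first row is scored, which by the reasoning on uniform logits gives a perplexity equal to vocab_size (3 here).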

Return type:

Tensor

Returns:

Perplexity value

Examples

>>> from torch import rand, randint
>>> from torchmetrics.functional.text.perplexity import perplexity
>>> preds = rand(2, 8, 5)
>>> target = randint(5, (2, 8))
>>> target[0, 6:] = -100
>>> perplexity(preds, target, ignore_index=-100)
tensor(5.8540)