KL Divergence

Module Interface

class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)[source]

Compute the KL divergence.

\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]

Where \(P\) and \(Q\) are probability distributions: \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). Note that the KL divergence is not symmetric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
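
The sum in the definition maps directly onto elementwise tensor operations. As a minimal sketch (assuming p and q are already valid probability distributions, i.e. non-negative rows that sum to 1), the value returned in the Example below can be reproduced by hand:

>>> from torch import tensor
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> (p * (p / q).log()).sum(dim=-1)  # sum_x P(x) * log(P(x) / Q(x)) per sample
tensor([0.0853])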

As input to forward and update the metric accepts the following input:

  • p (Tensor): a data distribution with shape (N, d)

  • q (Tensor): prior or approximate distribution with shape (N, d)

As output of forward and compute the metric returns the following output:

  • kl_divergence (Tensor): A tensor with the KL divergence
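
For illustration of the shapes, a short sketch with random (hypothetical) inputs; since log_prob=False is the default, the rows are normalized internally, so unnormalized values are fine:

>>> from torch import rand
>>> from torchmetrics.regression import KLDivergence
>>> KLDivergence()(rand(4, 3), rand(4, 3)).shape                   # default 'mean' reduction -> scalar
torch.Size([])
>>> KLDivergence(reduction='none')(rand(4, 3), rand(4, 3)).shape   # one score per sample
torch.Size([4])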

Warning

The input order and naming in metric KLDivergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and will be moved to be the second argument of the metric. Input argument q will be renamed to preds and will be moved to the first argument of the metric. Thus, KLDivergence(p, q) will equal KLDivergence(target=q, preds=p) in the future to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5 the two old arguments will be removed.

Parameters:
  • log_prob (bool) – bool indicating if the input is log-probabilities or probabilities. If given as probabilities, the inputs will be normalized so that each distribution sums to 1; see the sketch after this parameter list.

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages the score across samples

    • 'sum': Sums the score across samples

    • 'none' or None: Returns the score per sample

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
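
A brief sketch of the log_prob argument, reusing the distributions from the Example below: because the rows already sum to 1, feeding them as log-probabilities yields the same result.

>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> KLDivergence(log_prob=True)(p.log(), q.log())
tensor(0.0853)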

Raises:
  • TypeError – If log_prob is not a bool.

  • ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

Note

Half precision is only supported on GPU for this metric.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axes object. If provided, the plot will be added to that axis.

Return type:

Tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)

Functional Interface

torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')[source]

Compute KL divergence.

\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]

Where \(P\) and \(Q\) are probability distributions: \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). Note that the KL divergence is not symmetric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).

Warning

The input order and naming in metric kl_divergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and will be moved to be the second argument of the metric. Input argument q will be renamed to preds and will be moved to the first argument of the metric. Thus, kl_divergence(p, q) will equal kl_divergence(target=q, preds=p) in the future to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5 the two old arguments will be removed.

Parameters:
  • p (Tensor) – data distribution with shape (N, d)

  • q (Tensor) – prior or approximate distribution with shape (N, d)

  • log_prob (bool) – bool indicating if the input is log-probabilities or probabilities. If given as probabilities, the inputs will be normalized so that each distribution sums to 1

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages the score across samples

    • 'sum': Sums the score across samples

    • 'none' or None: Returns the score per sample

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
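
The remaining arguments can be exercised on the same tensors. As a small sketch: reduction='none' returns one score per row, and because p and q already sum to 1, passing them as log-probabilities gives the same value.

>>> kl_divergence(p, q, reduction='none')
tensor([0.0853])
>>> kl_divergence(p.log(), q.log(), log_prob=True)
tensor(0.0853)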