KL Divergence
Module Interface
- class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)
Compute the KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]
where \(P\) and \(Q\) are probability distributions; \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). Note that the KL divergence is not symmetric, i.e. in general \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
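As a rough illustration of the formula and of the asymmetry, the sum can be evaluated by hand in plain Python. This is a minimal sketch, not the library's implementation; `kl_divergence_discrete` is a hypothetical helper introduced only for this example, and the convention that terms with \(P(x) = 0\) contribute nothing is an assumption.

```python
import math


def kl_divergence_discrete(p, q):
    """D_KL(P || Q) for two discrete distributions given as equal-length lists.

    Hypothetical helper for illustration; not part of torchmetrics.
    """
    # Assumption: terms with P(x) == 0 are skipped (0 * log 0 is taken as 0).
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


p = [0.36, 0.48, 0.16]
q = [1 / 3, 1 / 3, 1 / 3]

forward = kl_divergence_discrete(p, q)  # D_KL(P || Q)
reverse = kl_divergence_discrete(q, p)  # D_KL(Q || P)

print(f"{forward:.4f}")  # 0.0853
print(f"{reverse:.4f}")  # 0.0975
```

Swapping the arguments changes the result, so the order in which the two distributions are passed matters.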
As input to forward and update the metric accepts the following input:
- p (Tensor): a data distribution with shape (N, d)
- q (Tensor): prior or approximate distribution with shape (N, d)

As output of forward and compute the metric returns the following output:
- kl_divergence (Tensor): A tensor with the KL divergence
Warning
The input order and naming in metric KLDivergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and will be moved to be the second argument of the metric. Input argument q will be renamed to preds and will be moved to the first argument of the metric. Thus, KLDivergence(p, q) will equal KLDivergence(target=q, preds=p) in the future to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5 the two old arguments will be removed.
- Parameters:
log_prob (bool) – Whether the input is given as log-probabilities (True) or probabilities (False). If given as probabilities, the inputs are normalized so that each distribution sums to 1.
reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
- 'mean' [default]: Averages score across samples
- 'sum': Sum score across samples
- 'none' or None: Returns score per sample
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
TypeError – If log_prob is not a bool.
ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.
Note
Half precision is only supported on GPU for this metric.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)
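The effect of the log_prob flag can be sketched in plain Python without torchmetrics. The helper below is hypothetical and only mirrors the behavior described in the parameter list: with probabilities, each input is first rescaled to sum to 1 (assumed here to mean division by its sum); with log-probabilities, the inputs are used as given. It should not be read as the library's actual code.

```python
import math


def per_sample_kl(p, q, log_prob=False):
    """KL divergence of one sample under either input convention.

    Hypothetical helper for illustration; not part of torchmetrics.
    """
    if log_prob:
        # Inputs are log-probabilities: P(x) = exp(log P(x)) and
        # log(P(x) / Q(x)) = log P(x) - log Q(x).
        return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(p, q))
    # Inputs are probabilities. Assumption: "normalize" means dividing
    # each distribution by its own sum so that it adds up to 1.
    p_total, q_total = sum(p), sum(q)
    p = [x / p_total for x in p]
    q = [x / q_total for x in q]
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


probs_p = [0.36, 0.48, 0.16]
probs_q = [1 / 3, 1 / 3, 1 / 3]
unnormalized_p = [36.0, 48.0, 16.0]  # same proportions as probs_p, sums to 100

from_probs = per_sample_kl(probs_p, probs_q)
from_unnormalized = per_sample_kl(unnormalized_p, probs_q)
from_logs = per_sample_kl(
    [math.log(x) for x in probs_p],
    [math.log(x) for x in probs_q],
    log_prob=True,
)
```

Under these assumptions all three calls give the same value up to floating-point error, which is why probabilities that do not sum exactly to 1 are still usable when log_prob=False.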
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()

>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)
Functional Interface
- torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')
Compute KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]
where \(P\) and \(Q\) are probability distributions; \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). Note that the KL divergence is not symmetric, i.e. in general \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
Warning
The input order and naming in metric kl_divergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and will be moved to be the second argument of the metric. Input argument q will be renamed to preds and will be moved to the first argument of the metric. Thus, kl_divergence(p, q) will equal kl_divergence(target=q, preds=p) in the future to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments and from v1.5 the two old arguments will be removed.
- Parameters:
p (Tensor) – data distribution with shape [N, d]
q (Tensor) – prior or approximate distribution with shape [N, d]
log_prob (bool) – Whether the input is given as log-probabilities (True) or probabilities (False). If given as probabilities, the inputs are normalized so that each distribution sums to 1.
reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
- 'mean' [default]: Averages score across samples
- 'sum': Sum score across samples
- 'none' or None: Returns score per sample
- Return type:
Example
>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
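The three reduction options can be illustrated on a small batch in plain Python. This is a sketch under stated assumptions: `sample_kl` and `reduce_scores` are hypothetical helpers, the batch is a list of rows standing in for an (N, d) tensor, and the code is not taken from the library.

```python
import math


def sample_kl(p, q):
    """Per-sample D_KL(P || Q); hypothetical helper, not torchmetrics code."""
    return sum(px * math.log(px / qx) for px, qx in zip(p, q) if px > 0)


def reduce_scores(scores, reduction="mean"):
    """Reduce per-sample scores over the batch dimension (illustrative only)."""
    if reduction == "mean":
        return sum(scores) / len(scores)
    if reduction == "sum":
        return sum(scores)
    if reduction in ("none", None):
        return list(scores)
    raise ValueError(f"Unsupported reduction: {reduction!r}")


# Batch of N = 2 samples over d = 3 outcomes.
p_batch = [[0.36, 0.48, 0.16], [0.2, 0.3, 0.5]]
q_batch = [[1 / 3, 1 / 3, 1 / 3], [0.2, 0.3, 0.5]]

scores = [sample_kl(p, q) for p, q in zip(p_batch, q_batch)]

mean_score = reduce_scores(scores, "mean")  # one value, averaged over samples
sum_score = reduce_scores(scores, "sum")    # one value, summed over samples
per_sample = reduce_scores(scores, "none")  # one value per sample
```

The second sample compares a distribution with itself, so its score is zero and the mean is half of the sum for this batch.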