KL Divergence
Module Interface
- class torchmetrics.KLDivergence(log_prob=False, reduction='mean', **kwargs)
Compute the KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]
Where \(P\) and \(Q\) are probability distributions, \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). It should be noted that the KL divergence is a non-symmetric metric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
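For instance, for \(P = (0.36, 0.48, 0.16)\) and the uniform prior \(Q = (1/3, 1/3, 1/3)\) (the same values as in the example below), using natural logarithms:
\[D_{KL}(P||Q) = 0.36\log\frac{0.36}{1/3} + 0.48\log\frac{0.48}{1/3} + 0.16\log\frac{0.16}{1/3} \approx 0.0277 + 0.1750 - 0.1174 \approx 0.0853\]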
As input to forward and update the metric accepts the following input:
- p (Tensor): a data distribution with shape (N, d)
- q (Tensor): prior or approximate distribution with shape (N, d)
As output of forward and compute the metric returns the following output:
- kl_divergence (Tensor): a tensor with the KL divergence
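Like other metric modules, KLDivergence is stateful: update accumulates statistics over batches and compute reduces over all samples seen so far. A minimal sketch of this usage, with random softmax-normalized tensors standing in for real distributions:
>>> from torch import randn
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()  # reduction='mean' by default
>>> for _ in range(2):
...     p = randn(8, 4).softmax(dim=-1)  # data distribution, shape (N, d)
...     q = randn(8, 4).softmax(dim=-1)  # prior/approximate distribution, shape (N, d)
...     metric.update(p, q)
>>> score = metric.compute()  # scalar: mean divergence over all 16 samples
>>> metric.reset()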
Warning
The input order and naming in metric KLDivergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and moved to be the second argument of the metric. Input argument q will be renamed to preds and moved to be the first argument of the metric. Thus, KLDivergence(p, q) will equal KLDivergence(target=q, preds=p) in the future, to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments, and from v1.5 the two old arguments will be removed.
- Parameters:
- log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1.
- reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
  - 'mean' [default]: Averages score across samples
  - 'sum': Sums score across samples
  - 'none' or None: Returns score per sample
- kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
- Raises:
- TypeError – If log_prob is not a bool.
- ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.
Note
Half precision is only supported on GPU for this metric.
Example
>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence = KLDivergence()
>>> kl_divergence(p, q)
tensor(0.0853)
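A short sketch of how the log_prob and reduction arguments affect the call, using the same distributions as the example above (variable names are illustrative):
>>> from torch import tensor
>>> from torchmetrics.regression import KLDivergence
>>> p = tensor([[0.36, 0.48, 0.16]])  # data distribution, N=1 sample over d=3 classes
>>> q = tensor([[1/3, 1/3, 1/3]])     # uniform prior
>>> per_sample = KLDivergence(reduction='none')(p, q)     # one score per sample, shape (1,)
>>> total = KLDivergence(reduction='sum')(p, q)           # scalar: sum over samples
>>> same = KLDivergence(log_prob=True)(p.log(), q.log())  # log-probability inputs, same divergence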
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
- val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
- ax (Optional[Axes]) – A matplotlib axis object. If provided, will add the plot to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
- ModuleNotFoundError – If matplotlib is not installed
>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import KLDivergence
>>> metric = KLDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)
Functional Interface
- torchmetrics.functional.kl_divergence(p, q, log_prob=False, reduction='mean')
Compute KL divergence.
\[D_{KL}(P||Q) = \sum_{x\in\mathcal{X}} P(x) \log\frac{P(x)}{Q(x)}\]
Where \(P\) and \(Q\) are probability distributions, \(P\) usually represents a distribution over data and \(Q\) is often a prior or an approximation of \(P\). It should be noted that the KL divergence is a non-symmetric metric, i.e. \(D_{KL}(P||Q) \neq D_{KL}(Q||P)\).
Warning
The input order and naming in metric kl_divergence is set to be deprecated in v1.4 and changed in v1.5. Input argument p will be renamed to target and moved to be the second argument of the metric. Input argument q will be renamed to preds and moved to be the first argument of the metric. Thus, kl_divergence(p, q) will equal kl_divergence(target=q, preds=p) in the future, to be consistent with the rest of torchmetrics. From v1.4 the two new arguments will be added as keyword arguments, and from v1.5 the two old arguments will be removed.
- Parameters:
- p (Tensor) – data distribution with shape [N, d]
- q (Tensor) – prior or approximate distribution with shape [N, d]
- log_prob (bool) – bool indicating if input is log-probabilities or probabilities. If given as probabilities, will normalize to make sure the distributions sum to 1.
- reduction (Literal['mean', 'sum', 'none', None]) – Determines how to reduce over the N/batch dimension:
  - 'mean' [default]: Averages score across samples
  - 'sum': Sums score across samples
  - 'none' or None: Returns score per sample
- Return type:
Tensor
Example
>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> kl_divergence(p, q)
tensor(0.0853)
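The reduction argument works the same way as in the module version; a brief sketch with a two-sample batch (the second row is an arbitrary example distribution):
>>> from torch import tensor
>>> from torchmetrics.functional import kl_divergence
>>> p = tensor([[0.36, 0.48, 0.16], [0.25, 0.50, 0.25]])
>>> q = tensor([[1/3, 1/3, 1/3], [1/3, 1/3, 1/3]])
>>> kl_divergence(p, q, reduction='none').shape  # one score per sample
torch.Size([2])
>>> total = kl_divergence(p, q, reduction='sum')  # scalar: sum over both samples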