Jensen-Shannon Divergence

Module Interface

class torchmetrics.regression.JensenShannonDivergence(log_prob=False, reduction='mean', **kwargs)[source]

Compute the Jensen-Shannon divergence.

$$D_{JS}(P\|Q) = \frac{1}{2} D_{KL}(P\|M) + \frac{1}{2} D_{KL}(Q\|M)$$

where P and Q are probability distributions, with P usually representing a distribution over data and Q often being a prior or approximation of P. $D_{KL}$ denotes the KL divergence and $M = \frac{1}{2}(P + Q)$ is the average of the two distributions. Note that the Jensen-Shannon divergence is symmetric, i.e. $D_{JS}(P\|Q) = D_{JS}(Q\|P)$.
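
For intuition, the definition can be evaluated directly in PyTorch; the sketch below mirrors the formula (with a 'mean' reduction over the batch) rather than the metric's internal implementation:

>>> import torch
>>> def js_div_sketch(p, q):
...     # M is the average of the two distributions
...     m = 0.5 * (p + q)
...     # per-sample D_KL(P || M) and D_KL(Q || M)
...     kl_pm = (p * (p / m).log()).sum(dim=-1)
...     kl_qm = (q * (q / m).log()).sum(dim=-1)
...     # average the two halves, then reduce over the batch ('mean')
...     return (0.5 * kl_pm + 0.5 * kl_qm).mean()
>>> p = torch.tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]])
>>> q = torch.tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
>>> js_div_sketch(p, q)
tensor(0.0259)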

As input to forward and update the metric accepts the following input:

  • p (Tensor): a data distribution with shape (N, d)

  • q (Tensor): prior or approximate distribution with shape (N, d)

As output of forward and compute the metric returns the following output:

  • js_divergence (Tensor): A tensor with the Jensen-Shannon divergence

Parameters:
  • log_prob (bool) – bool indicating if the inputs are log-probabilities or probabilities. If given as probabilities, they will be normalized to make sure the distributions sum to 1.

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages score across samples

    • 'sum': Sums score across samples

    • 'none' or None: Returns score per sample

  • kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.

Raises:
  • TypeError – If log_prob is not a bool.

  • ValueError – If reduction is not one of 'mean', 'sum', 'none' or None.

Attention

Half precision is only supported on GPU for this metric.

Example

>>> from torch import tensor
>>> from torchmetrics.regression import JensenShannonDivergence
>>> p = tensor([[0.1, 0.9], [0.2, 0.8], [0.3, 0.7]])
>>> q = tensor([[0.3, 0.7], [0.4, 0.6], [0.5, 0.5]])
>>> js_div = JensenShannonDivergence()
>>> js_div(p, q)
tensor(0.0259)
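
The same tensors can also be scored per sample or passed as log-probabilities; the expected values below follow from the definition above (the three per-sample scores average to the 0.0259 shown in the example):

>>> # per-sample scores instead of the batch mean
>>> js_div_none = JensenShannonDivergence(reduction='none')
>>> js_div_none(p, q)
tensor([0.0324, 0.0242, 0.0210])
>>> # the same distributions given as log-probabilities
>>> js_div_log = JensenShannonDivergence(log_prob=True)
>>> js_div_log(p.log(), q.log())
tensor(0.0259)
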
plot(val=None, ax=None)[source]

Plot a single or multiple values from the metric.

Parameters:
  • val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.

  • ax (Optional[Axes]) – A matplotlib axis object. If provided, will add plot to that axis

Return type:

tuple[Figure, Union[Axes, ndarray]]

Returns:

Figure and Axes object

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> from torch import randn
>>> # Example plotting a single value
>>> from torchmetrics.regression import JensenShannonDivergence
>>> metric = JensenShannonDivergence()
>>> metric.update(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1))
>>> fig_, ax_ = metric.plot()
>>> from torch import randn
>>> # Example plotting multiple values
>>> from torchmetrics.regression import JensenShannonDivergence
>>> metric = JensenShannonDivergence()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(randn(10,3).softmax(dim=-1), randn(10,3).softmax(dim=-1)))
>>> fig, ax = metric.plot(values)

Functional Interface

torchmetrics.functional.regression.jensen_shannon_divergence(p, q, log_prob=False, reduction='mean')[source]

Compute Jensen-Shannon divergence.

$$D_{JS}(P\|Q) = \frac{1}{2} D_{KL}(P\|M) + \frac{1}{2} D_{KL}(Q\|M)$$

where P and Q are probability distributions, with P usually representing a distribution over data and Q often being a prior or approximation of P. $D_{KL}$ denotes the KL divergence and $M = \frac{1}{2}(P + Q)$ is the average of the two distributions. Note that the Jensen-Shannon divergence is symmetric, i.e. $D_{JS}(P\|Q) = D_{JS}(Q\|P)$.

Parameters:
  • p (Tensor) – data distribution with shape [N, d]

  • q (Tensor) – prior or approximate distribution with shape [N, d]

  • log_prob (bool) – bool indicating if the inputs are log-probabilities or probabilities. If given as probabilities, they will be normalized to make sure the distributions sum to 1

  • reduction (Literal['mean', 'sum', 'none', None]) –

    Determines how to reduce over the N/batch dimension:

    • 'mean' [default]: Averages score across samples

    • 'sum': Sums score across samples

    • 'none' or None: Returns score per sample

Return type:

Tensor

Example

>>> from torch import tensor
>>> from torchmetrics.functional.regression import jensen_shannon_divergence
>>> p = tensor([[0.36, 0.48, 0.16]])
>>> q = tensor([[1/3, 1/3, 1/3]])
>>> jensen_shannon_divergence(p, q)
tensor(0.0225)
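
Since the divergence is symmetric, swapping the arguments should return the same value; shown here as an illustrative check:

>>> # symmetry: D_JS(P||Q) == D_JS(Q||P)
>>> jensen_shannon_divergence(q, p)
tensor(0.0225)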
