
# Permutation Invariant Training (PIT)

## Module Interface

class torchmetrics.audio.PermutationInvariantTraining(metric_func, mode='speaker-wise', eval_func='max', **kwargs)

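Calculate Permutation Invariant Training (PIT).

This metric evaluates models for speaker-independent, multi-talker speech separation in a permutation-invariant way: the estimated sources are matched to the reference sources using the speaker permutation that gives the best metric value.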

As input to forward and update, the metric accepts the following input:

• preds (Tensor): float tensor with shape (batch_size, num_speakers, ...)

• target (Tensor): float tensor with shape (batch_size, num_speakers, ...)

As output of forward and compute, the metric returns the following output:

• pit (Tensor): float scalar tensor with the average PIT value over samples
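The scalar returned by forward and compute is an average over individual samples. Below is a minimal sketch of that bookkeeping, assuming the module keeps a running sum of the per-sample best values and a running sample count. RunningPITMean is a hypothetical name, not part of torchmetrics, and for simplicity it takes the per-sample best values (for example the first tensor returned by the functional interface) instead of raw preds and target.

```python
import torch


class RunningPITMean:
    """Hypothetical accumulator: averages per-sample best PIT values across updates."""

    def __init__(self) -> None:
        self.total_value = torch.tensor(0.0)  # running sum of per-sample best values
        self.num_samples = torch.tensor(0)  # running count of samples seen

    def update(self, best_metric: torch.Tensor) -> None:
        # best_metric has shape (batch,): the best value found for each sample
        self.total_value += best_metric.sum()
        self.num_samples += best_metric.numel()

    def compute(self) -> torch.Tensor:
        # Scalar average over every sample seen, regardless of batch boundaries
        return self.total_value / self.num_samples


acc = RunningPITMean()
acc.update(torch.tensor([1.0, 3.0]))  # a batch of 2 samples
acc.update(torch.tensor([5.0]))  # a batch of 1 sample
print(acc.compute())  # tensor(3.)
```

Averaging over samples rather than over batches weights every sample equally: here the result is 3.0, whereas averaging the two batch means would give (2.0 + 5.0) / 2 = 3.5.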

Parameters:
• metric_func (Callable) –

a metric function that accepts a batch of estimates and targets.

If mode=='speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called for each pair of speakers (i, j) and is expected to return a batch of metric values with shape (batch,);

If mode=='permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0, 1] or [1, 0] in the 2-speaker case, and is expected to return a batch of metric values with shape (batch,);

• mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.

• eval_func (Literal['max', 'min']) – the function used to select the best permutation: 'max' if larger metric values are better, 'min' if smaller values are better.
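To make the speaker-wise contract concrete, here is a minimal exhaustive-search sketch in plain PyTorch, not the torchmetrics implementation. It builds a matrix of metric values for every (target, estimate) pair, scores each permutation by the mean of its matched entries, and keeps the best one according to eval_func. The helpers neg_mse and brute_force_pit are hypothetical, and the negative mean squared error is only a stand-in metric where larger is better.

```python
import itertools

import torch


def neg_mse(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical speaker-wise metric: (batch, time) x (batch, time) -> (batch,); larger is better."""
    return -((preds - target) ** 2).mean(dim=-1)


def brute_force_pit(preds, target, metric_func, eval_func="max"):
    """Hypothetical sketch of speaker-wise PIT by exhaustive search over permutations."""
    batch, num_spk = preds.shape[:2]
    # metric_mtx[b, t, p] = metric of estimate p scored against target t
    metric_mtx = torch.empty(batch, num_spk, num_spk)
    for t in range(num_spk):
        for p in range(num_spk):
            metric_mtx[:, t, p] = metric_func(preds[:, p], target[:, t])
    # All num_spk! orderings; row k assigns estimate perms[k, t] to target t
    perms = torch.tensor(list(itertools.permutations(range(num_spk))))
    # scores[b, k] = mean over targets t of metric_mtx[b, t, perms[k, t]]
    scores = metric_mtx[:, torch.arange(num_spk), perms].mean(dim=-1)
    best = scores.max(dim=-1) if eval_func == "max" else scores.min(dim=-1)
    return best.values, perms[best.indices]  # shapes (batch,) and (batch, num_spk)


torch.manual_seed(0)
target = torch.randn(4, 2, 8)  # [batch, spk, time]
preds = target[:, [1, 0]]  # estimates with the two speakers swapped
best_metric, best_perm = brute_force_pit(preds, target, neg_mse, eval_func="max")
print(best_perm)  # each row is [1, 0]: estimate 1 matches target 0
```

In this sketch, speaker-wise mode calls metric_func only num_spk² times, because every permutation score is assembled from the same pairwise matrix; the number of permutations still grows as num_spk!.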

Example

>>> from torch import randn
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = randn(3, 2, 5) # [batch, spk, time]
>>> target = randn(3, 2, 5) # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-2.1065)
plot(val=None, ax=None)

Plot a single or multiple values from the metric.

Parameters:
• val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and that result is plotted.

• ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.

Returns:

A tuple of the matplotlib Figure and Axes objects

Raises:

ModuleNotFoundError – If matplotlib is not installed

>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5) # [batch, spk, time]
>>> target = torch.randn(3, 2, 5) # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)

## Functional Interface

torchmetrics.functional.audio.permutation_invariant_training(preds, target, metric_func, mode='speaker-wise', eval_func='max', **kwargs)

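Calculate Permutation Invariant Training (PIT).

This metric evaluates models for speaker-independent, multi-talker speech separation in a permutation-invariant way: the estimated sources are matched to the reference sources using the speaker permutation that gives the best metric value.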

Parameters:
• preds (Tensor) – float tensor with shape (batch_size,num_speakers,...)

• target (Tensor) – float tensor with shape (batch_size,num_speakers,...)

• metric_func (Callable) –

a metric function that accepts a batch of estimates and targets.

If mode=='speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called for each pair of speakers (i, j) and is expected to return a batch of metric values with shape (batch,);

If mode=='permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0, 1] or [1, 0] in the 2-speaker case, and is expected to return a batch of metric values with shape (batch,);

• mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.

• eval_func (Literal['max', 'min']) – the function used to select the best permutation: 'max' if larger metric values are better, 'min' if smaller values are better.

• kwargs (Any) – additional keyword arguments passed to metric_func
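In permutation-wise mode the metric sees all speakers at once, which suits metrics that cannot be decomposed into independent per-speaker scores. The sketch below, which is not the torchmetrics implementation, calls a hypothetical joint metric once per candidate reordering preds[:, p] and keeps the best permutation; joint_neg_mse and brute_force_pit_permutation_wise are illustrative names only.

```python
import itertools

import torch


def joint_neg_mse(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Hypothetical permutation-wise metric: sees all speakers at once.

    preds, target: (batch, num_spk, time) -> (batch,); larger is better.
    """
    return -((preds - target) ** 2).mean(dim=(1, 2))


def brute_force_pit_permutation_wise(preds, target, metric_func, eval_func="max"):
    """Hypothetical sketch: call metric_func(preds[:, p], target) once per permutation p."""
    perms = list(itertools.permutations(range(preds.shape[1])))
    # scores[b, k] = metric of the k-th reordering of the estimates for sample b
    scores = torch.stack([metric_func(preds[:, list(p)], target) for p in perms], dim=-1)
    best = scores.max(dim=-1) if eval_func == "max" else scores.min(dim=-1)
    return best.values, torch.tensor(perms)[best.indices]


torch.manual_seed(0)
target = torch.randn(3, 3, 16)  # [batch, spk, time]
preds = target[:, [2, 0, 1]]  # estimates in a shuffled speaker order
best_metric, best_perm = brute_force_pit_permutation_wise(preds, target, joint_neg_mse)
print(best_perm)  # each row is [1, 2, 0]
```

The trade-off in this sketch is cost: metric_func runs num_spk! times on full (batch, num_spk, ...) tensors, so this mode is practical only for small speaker counts.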

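Return type:
Tuple[Tensor, Tensor]

Returns:

A tuple of two tensors. The first is a float tensor with shape (batch,) containing the best metric value for each sample; the second is an integer tensor with shape (batch, num_speakers) containing the best permutation for each sample, which can be passed to pit_permutate to reorder preds to match target.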
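Exhaustively scoring all num_speakers! permutations becomes expensive as the speaker count grows. In speaker-wise mode the permutation score is a mean of pairwise entries, so finding the best permutation for one sample is a linear assignment problem. A hedged sketch using SciPy's Hungarian-method solver follows; it is a swapped-in technique shown for illustration, not necessarily how torchmetrics computes best_perm, and best_perm_by_assignment is a hypothetical helper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def best_perm_by_assignment(metric_mtx: np.ndarray, eval_func: str = "max"):
    """Hypothetical sketch: best speaker-wise permutation for ONE sample.

    metric_mtx[t, p] is the metric of estimate p scored against target t.
    Returns (best mean metric, perm) where perm[t] is the estimate matched to target t.
    """
    # Maximizing (or minimizing) the sum is equivalent to the mean for a fixed speaker count
    target_idx, perm = linear_sum_assignment(metric_mtx, maximize=(eval_func == "max"))
    return metric_mtx[target_idx, perm].mean(), perm


metric_mtx = np.array([[1.0, 9.0, 2.0],
                       [8.0, 3.0, 1.0],
                       [2.0, 1.0, 7.0]])
value, perm = best_perm_by_assignment(metric_mtx, "max")
print(perm, value)  # [1 0 2] 8.0
```

For a square matrix the solver returns target indices in sorted order, so perm lines up with the target axis the same way as a row of best_perm.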

Example

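>>> import torch
>>> from torchmetrics.functional.audio import permutation_invariant_training, pit_permutate
>>> from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio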
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])
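pit_permutate reorders the speaker axis of preds according to a permutation. Assuming it follows out[b, t] = preds[b, perm[b, t]] (consistent with the example, where the identity permutation [0, 1] leaves preds unchanged), the same reordering can be sketched with advanced indexing. permutate_sketch is a hypothetical name, not a torchmetrics function.

```python
import torch


def permutate_sketch(preds: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
    """Hypothetical re-implementation: out[b, t] = preds[b, perm[b, t]].

    preds: (batch, num_spk, ...); perm: (batch, num_spk) integer tensor.
    """
    # (batch, 1) batch indices broadcast against perm's (batch, num_spk) shape
    batch_idx = torch.arange(preds.shape[0]).unsqueeze(-1)
    return preds[batch_idx, perm]


preds = torch.tensor([[[-0.0579, 0.3560, -0.9604], [-0.1719, 0.3205, 0.2951]]])
print(permutate_sketch(preds, torch.tensor([[1, 0]])))  # the two speaker rows swapped
```

Indexing with a per-sample permutation lets every item in the batch be reordered differently in a single operation.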