Permutation Invariant Training (PIT)
Module Interface
- class torchmetrics.audio.PermutationInvariantTraining(metric_func, mode='speaker-wise', eval_func='max', **kwargs)[source]
Calculate Permutation Invariant Training (PIT).
This metric can evaluate models for speaker-independent multi-talker speech separation in a permutation invariant way.
As input to forward and update the metric accepts the following input:

- preds (Tensor): float tensor with shape (batch_size, num_speakers, ...)
- target (Tensor): float tensor with shape (batch_size, num_speakers, ...)

As output of forward and compute the metric returns the following output:

- pit (Tensor): float scalar tensor with the average PIT metric value over samples
- Parameters:
  - metric_func – a metric function that accepts a batch of target and estimate. If mode=='speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and expected to return a batch of metric tensors with shape (batch,); if mode=='permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0, 1] or [1, 0] for the 2-speaker case, and expected to return a batch of metric tensors with shape (batch,). A sketch of a custom permutation-wise metric function follows the example below.
  - mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.
  - eval_func (Literal['max', 'min']) – the function used to find the best permutation; can be 'min' or 'max', i.e. the smaller the better or the larger the better.
  - kwargs (Any) – additional keyword arguments for either the metric_func or distributed communication, see Advanced metric settings for more info.
Example
>>> from torch import randn
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = randn(3, 2, 5)   # [batch, spk, time]
>>> target = randn(3, 2, 5)  # [batch, spk, time]
>>> pit = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-2.1065)
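The example above uses the 'speaker-wise' mode. The following is a minimal sketch of the 'permutation-wise' calling convention with a custom metric function; pairwise_mse is a hypothetical helper, not part of torchmetrics, and eval_func="min" is used because mean squared error is a smaller-is-better metric.

>>> # Sketch only: pairwise_mse is a hypothetical metric_func, not a torchmetrics API.
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> def pairwise_mse(preds, target):
...     # In permutation-wise mode both arguments are [batch, spk, time];
...     # preds arrives with its speaker axis already permuted.
...     return ((preds - target) ** 2).mean(dim=[1, 2])  # shape (batch,)
>>> preds = torch.randn(3, 2, 5)   # [batch, spk, time]
>>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> pit = PermutationInvariantTraining(pairwise_mse,
...     mode="permutation-wise", eval_func="min")
>>> score = pit(preds, target)  # average of the per-sample best (lowest) MSE
>>> score.shape
torch.Size([])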
- plot(val=None, ax=None)[source]
Plot a single or multiple values from the metric.
- Parameters:
  - val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, will automatically call metric.compute and plot that result.
  - ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot will be added to that axis.
- Return type:
  tuple[Figure, Axes]
- Returns:
  Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5)   # [batch, spk, time]
>>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> metric.update(preds, target)
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio
>>> preds = torch.randn(3, 2, 5)   # [batch, spk, time]
>>> target = torch.randn(3, 2, 5)  # [batch, spk, time]
>>> metric = PermutationInvariantTraining(scale_invariant_signal_noise_ratio,
...     mode="speaker-wise", eval_func="max")
>>> values = []
>>> for _ in range(10):
...     values.append(metric(preds, target))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.audio.permutation_invariant_training(preds, target, metric_func, mode='speaker-wise', eval_func='max', **kwargs)[source]
Calculate Permutation Invariant Training (PIT).
This metric can evaluate models for speaker-independent multi-talker speech separation in a permutation invariant way.
- Parameters:
  - preds (Tensor) – float tensor with shape (batch_size, num_speakers, ...)
  - target (Tensor) – float tensor with shape (batch_size, num_speakers, ...)
  - metric_func – a metric function that accepts a batch of target and estimate. If mode=='speaker-wise', then metric_func(preds[:, i, ...], target[:, j, ...]) is called and expected to return a batch of metric tensors with shape (batch,); if mode=='permutation-wise', then metric_func(preds[:, p, ...], target[:, :, ...]) is called, where p is one possible permutation, e.g. [0, 1] or [1, 0] for the 2-speaker case, and expected to return a batch of metric tensors with shape (batch,).
  - mode (Literal['speaker-wise', 'permutation-wise']) – can be 'speaker-wise' or 'permutation-wise'.
  - eval_func (Literal['max', 'min']) – the function used to find the best permutation; can be 'min' or 'max', i.e. the smaller the better or the larger the better.
  - kwargs – additional keyword arguments for the metric_func.
- Return type:
  tuple[Tensor, Tensor]
- Returns:
  Tuple of two float tensors. The first tensor, with shape (batch,), contains the best metric value for each sample; the second tensor, with shape (batch,), contains the best permutation.
Example
>>> import torch
>>> from torchmetrics.functional.audio import (permutation_invariant_training,
...     pit_permutate, scale_invariant_signal_distortion_ratio)
>>> # [batch, spk, time]
>>> preds = torch.tensor([[[-0.0579,  0.3560, -0.9604], [-0.1719,  0.3205,  0.2951]]])
>>> target = torch.tensor([[[ 1.0958, -0.1648,  0.5228], [-0.4100,  1.1942, -0.5103]]])
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> best_metric
tensor([-5.1091])
>>> best_perm
tensor([[0, 1]])
>>> pit_permutate(preds, best_perm)
tensor([[[-0.0579,  0.3560, -0.9604],
         [-0.1719,  0.3205,  0.2951]]])
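Beyond evaluation, the functional interface is often used as a training objective. The following is a minimal sketch, not taken from the torchmetrics docs, assuming SI-SDR as the metric: since larger SI-SDR is better, the negated best metric value serves as the loss.

>>> # Sketch only: PIT used as a training loss, assuming the metric is
>>> # differentiable and larger-is-better (hence the negation).
>>> import torch
>>> from torchmetrics.functional.audio import (permutation_invariant_training,
...     scale_invariant_signal_distortion_ratio)
>>> preds = torch.randn(3, 2, 5, requires_grad=True)  # model output [batch, spk, time]
>>> target = torch.randn(3, 2, 5)
>>> best_metric, best_perm = permutation_invariant_training(
...     preds, target, scale_invariant_signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> loss = -best_metric.mean()  # minimizing the loss maximizes SI-SDR
>>> loss.backward()
>>> preds.grad.shape
torch.Size([3, 2, 5])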