Signal to Distortion Ratio (SDR)
Module Interface
- class torchmetrics.audio.SignalDistortionRatio(use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None, **kwargs)
Calculate Signal to Distortion Ratio (SDR) metric.
See SDR ref1 and SDR ref2 for details on the metric.
As input to forward and update, the metric accepts the following input:
- preds (Tensor): float tensor with shape (...,time)
- target (Tensor): float tensor with shape (...,time)
As output of forward and compute, the metric returns the following output:
- sdr (Tensor): float scalar tensor with the average SDR value over samples
- Parameters:
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This requires that fast-bss-eval is installed and PyTorch version >= 1.8. It can speed up the computation when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The length of the allowed distortion filter.
zero_mean (bool) – When set to True, the mean of all signals is subtracted before computing the metric.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
kwargs (Any) – Additional keyword arguments, see Advanced metric settings for more info.
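To make the roles of filter_length, zero_mean and load_diag more concrete, here is a minimal NumPy sketch of a filter-based SDR for a single pair of 1-D signals. It is an illustration under stated assumptions, not the torchmetrics or fast-bss-eval implementation: the helper name sdr_sketch is hypothetical, it builds the full matrix of delayed target copies and solves the normal equations directly (the Gaussian-elimination path), it assumes the signal is longer than the filter, and its values are not expected to match the library's output exactly.

```python
import numpy as np


def sdr_sketch(preds, target, filter_length=512, zero_mean=False, load_diag=None):
    """Simplified filter-based SDR in dB for 1-D signals (illustration only).

    Fits a causal least-squares filter of length ``filter_length`` that maps
    ``target`` onto ``preds`` and compares the energy of the filtered target
    with the energy of the residual. Assumes len(target) > filter_length.
    """
    preds = np.asarray(preds, dtype=np.float64)
    target = np.asarray(target, dtype=np.float64)
    if zero_mean:
        preds = preds - preds.mean()
        target = target - target.mean()
    n = target.shape[-1]
    # Column k holds the target delayed by k samples (zero-padded at the start).
    delayed = np.zeros((n, filter_length))
    for k in range(filter_length):
        delayed[k:, k] = target[: n - k]
    # Normal equations r @ h = b for the distortion filter h.
    r = delayed.T @ delayed
    b = delayed.T @ preds
    if load_diag is not None:
        r = r + load_diag * np.eye(filter_length)  # diagonal loading
    h = np.linalg.solve(r, b)  # direct solve (LU, i.e. Gaussian elimination)
    proj = delayed @ h         # part of preds explained by the filtered target
    err = preds - proj         # remaining distortion
    return 10.0 * np.log10(np.sum(proj**2) / np.sum(err**2))


rng = np.random.default_rng(0)
target = rng.standard_normal(2000)
noise = rng.standard_normal(2000)
# A short causal echo of the target plus a little noise.
estimate = np.convolve(target, [1.0, 0.5, 0.25])[:2000] + 0.01 * noise

high = sdr_sketch(estimate, target, filter_length=16)
print(high)                                                   # high: the echo is absorbed by the filter
print(sdr_sketch(noise, target, filter_length=16))            # strongly negative: unrelated signal
print(sdr_sketch(estimate + 5.0, target, filter_length=16))   # a DC offset counts as distortion
print(sdr_sketch(estimate + 5.0, target, filter_length=16, zero_mean=True))  # offset removed first
```

Because the filter absorbs any short linear distortion of the target (here a three-tap echo), only the part of the estimate that no length-filter_length filter can explain lowers the score, and rescaling the estimate leaves the result unchanged.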
Example
>>> from torch import randn
>>> from torchmetrics.audio import SignalDistortionRatio
>>> preds = randn(8000)
>>> target = randn(8000)
>>> sdr = SignalDistortionRatio()
>>> sdr(preds, target)
tensor(-11.9930)
>>> # use with pit
>>> from torchmetrics.audio import PermutationInvariantTraining
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = randn(4, 2, 8000)  # [batch, spk, time]
>>> target = randn(4, 2, 8000)
>>> pit = PermutationInvariantTraining(signal_distortion_ratio,
...     mode="speaker-wise", eval_func="max")
>>> pit(preds, target)
tensor(-11.7277)
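The use_cg_iter option replaces the direct solve for the distortion filter with a fixed number of conjugate gradient iterations. The NumPy sketch below shows the textbook algorithm on a dense system built the way a causal least-squares distortion filter would be, and compares it with np.linalg.solve. The helper names conjugate_gradient and filter_system are hypothetical, and this is not fast-bss-eval's code; it forms the dense matrix explicitly for clarity.

```python
import numpy as np


def conjugate_gradient(r, b, n_iter=10):
    """Run a fixed number of conjugate gradient steps on r @ h = b.

    r must be symmetric positive definite; the iteration starts from h = 0.
    """
    h = np.zeros_like(b)
    res = b.copy()            # residual b - r @ h for h = 0
    direction = res.copy()
    rs_old = res @ res
    for _ in range(n_iter):
        if rs_old == 0.0:     # already exact, nothing left to do
            break
        r_dir = r @ direction
        alpha = rs_old / (direction @ r_dir)
        h = h + alpha * direction
        res = res - alpha * r_dir
        rs_new = res @ res
        direction = res + (rs_new / rs_old) * direction
        rs_old = rs_new
    return h


def filter_system(preds, target, filter_length):
    """Normal equations for a causal least-squares filter mapping target to preds."""
    n = target.shape[-1]
    delayed = np.zeros((n, filter_length))
    for k in range(filter_length):
        delayed[k:, k] = target[: n - k]
    return delayed.T @ delayed, delayed.T @ preds


rng = np.random.default_rng(0)
target = rng.standard_normal(4000)
preds = np.convolve(target, [1.0, -0.3, 0.1])[:4000] + 0.1 * rng.standard_normal(4000)
r, b = filter_system(preds, target, filter_length=32)

h_direct = np.linalg.solve(r, b)            # direct solve (Gaussian elimination / LU)
h_cg = conjugate_gradient(r, b, n_iter=10)  # iterative approximation with 10 steps
print(np.linalg.norm(h_cg - h_direct) / np.linalg.norm(h_direct))
print(h_direct[:3])                         # close to the true taps [1.0, -0.3, 0.1]
```

For a white-noise target the system matrix is well conditioned, so 10 iterations already agree closely with the direct solution. Conjugate gradient convergence depends on conditioning, so for strongly correlated signals such as speech more iterations may be needed to reach the same accuracy.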
- plot(val=None, ax=None)
Plot a single or multiple values from the metric.
- Parameters:
val (Union[Tensor, Sequence[Tensor], None]) – Either a single result from calling metric.forward or metric.compute, or a list of these results. If no value is provided, metric.compute is called automatically and its result is plotted.
ax (Optional[Axes]) – A matplotlib axis object. If provided, the plot is added to that axis.
- Return type:
- Returns:
Figure and Axes object
- Raises:
ModuleNotFoundError – If matplotlib is not installed
>>> # Example plotting a single value
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> metric.update(torch.rand(8000), torch.rand(8000))
>>> fig_, ax_ = metric.plot()
>>> # Example plotting multiple values
>>> import torch
>>> from torchmetrics.audio import SignalDistortionRatio
>>> metric = SignalDistortionRatio()
>>> values = []
>>> for _ in range(10):
...     values.append(metric(torch.rand(8000), torch.rand(8000)))
>>> fig_, ax_ = metric.plot(values)
Functional Interface
- torchmetrics.functional.audio.signal_distortion_ratio(preds, target, use_cg_iter=None, filter_length=512, zero_mean=False, load_diag=None)
Calculate Signal to Distortion Ratio (SDR) metric. See SDR ref1 and SDR ref2 for details on the metric.
- Parameters:
preds (Tensor) – float tensor with shape (...,time)
target (Tensor) – float tensor with shape (...,time)
use_cg_iter (Optional[int]) – If provided, conjugate gradient descent is used to solve for the distortion filter coefficients instead of direct Gaussian elimination. This requires that fast-bss-eval is installed and PyTorch version >= 1.8. It can speed up the computation when the filters are long. A value of 10 has been shown to provide good accuracy in most cases and is sufficient when using this metric as a loss to train neural separation networks.
filter_length (int) – The length of the allowed distortion filter.
zero_mean (bool) – When set to True, the mean of all signals is subtracted before computing the metric.
load_diag (Optional[float]) – If provided, this small value is added to the diagonal coefficients of the system matrix when solving for the filter coefficients. This can help stabilize the metric when some reference signals may be zero.
- Return type:
Tensor
- Returns:
Float tensor with shape (...,) of SDR values per sample.
- Raises:
RuntimeError – If preds and target do not have the same shape.
Example
>>> from torch import randn
>>> from torchmetrics.functional.audio import signal_distortion_ratio
>>> preds = randn(8000)
>>> target = randn(8000)
>>> signal_distortion_ratio(preds, target)
tensor(-11.9930)
>>> # use with permutation_invariant_training
>>> from torchmetrics.functional.audio import permutation_invariant_training
>>> preds = randn(4, 2, 8000)  # [batch, spk, time]
>>> target = randn(4, 2, 8000)
>>> best_metric, best_perm = permutation_invariant_training(preds, target, signal_distortion_ratio)
>>> best_metric
tensor([-11.7748, -11.7948, -11.7160, -11.6254])
>>> best_perm
tensor([[1, 0],
        [1, 0],
        [1, 0],
        [0, 1]])
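In the functional example above, best_metric holds, for each batch element, the SDR averaged over speakers under the best speaker assignment, and best_perm holds that assignment. The following NumPy sketch illustrates speaker-wise permutation invariant training by brute force over all permutations. pit_speaker_wise is a hypothetical name, and to keep the sketch short it uses plain SNR as a stand-in metric instead of SDR; it is meant to convey the idea, not to reproduce the torchmetrics implementation.

```python
import itertools

import numpy as np


def snr_db(preds, target):
    """Stand-in metric: plain SNR in dB along the last axis (NOT SDR)."""
    noise = preds - target
    return 10.0 * np.log10(np.sum(target**2, axis=-1) / np.sum(noise**2, axis=-1))


def pit_speaker_wise(preds, target, metric, eval_func="max"):
    """Brute-force speaker-wise PIT for arrays shaped [batch, spk, time].

    For every permutation p, preds[:, p] is scored against target, the metric is
    averaged over speakers, and the best permutation is kept per batch element.
    """
    batch, spk, _ = preds.shape
    perms = list(itertools.permutations(range(spk)))
    # scores[b, i] = speaker-averaged metric of batch element b under perms[i]
    scores = np.stack(
        [metric(preds[:, list(p), :], target).mean(axis=-1) for p in perms], axis=-1
    )
    pick = np.argmax if eval_func == "max" else np.argmin
    best_idx = pick(scores, axis=-1)
    best_metric = scores[np.arange(batch), best_idx]
    best_perm = np.array(perms)[best_idx]
    return best_metric, best_perm


# Toy data: in batch element 0 the two estimated speakers are swapped.
rng = np.random.default_rng(0)
target = rng.standard_normal((2, 2, 1000))
preds = target.copy()
preds[0] = target[0, ::-1]
preds = preds + 0.01 * rng.standard_normal(preds.shape)

best_metric, best_perm = pit_speaker_wise(preds, target, snr_db)
print(best_metric)
print(best_perm)
```

Here the swap in batch element 0 is undone by the permutation (1, 0), while batch element 1 keeps the identity (0, 1). Brute force costs spk! metric evaluations per batch, which is affordable for the small speaker counts typical in separation tasks.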