# Example: plot a single value of the PermutationInvariantTraining metric.
import torch

from torchmetrics.audio import PermutationInvariantTraining
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio

# Random stand-ins for separated and reference signals. No seed is set, so
# the plotted value differs from run to run.
preds = torch.randn(3, 2, 5)  # [batch, spk, time]
target = torch.randn(3, 2, 5)  # [batch, spk, time]

# Wrap SI-SNR in the permutation-invariant metric, using the "speaker-wise"
# mode and the "max" evaluation function.
# NOTE(review): presumably "max" keeps the speaker permutation with the
# highest SI-SNR -- confirm against the torchmetrics documentation.
metric = PermutationInvariantTraining(
    scale_invariant_signal_noise_ratio,
    mode="speaker-wise",
    eval_func="max",
)

# Accumulate one batch, then plot the metric's current state.
metric.update(preds, target)
fig_, ax_ = metric.plot()