"""Example: update a ScaleInvariantSignalDistortionRatio metric once and plot it.

Random ``preds`` and ``target`` signals are fed to the metric through a single
``update`` call. ``plot()`` is then called, and the figure and axes it returns
are unpacked.
"""
import torch
from torchmetrics.audio import ScaleInvariantSignalDistortionRatio

# Random 1-D tensors of length 5 stand in for real estimated/reference signals.
target = torch.randn(5)
preds = torch.randn(5)

metric = ScaleInvariantSignalDistortionRatio()
# Accumulate a single batch; note the argument order is (preds, target).
metric.update(preds, target)

# plot() returns a (figure, axes) pair. The trailing-underscore names only keep
# the handles alive. NOTE(review): plotting presumably needs matplotlib installed.
fig_, ax_ = metric.plot()