"""Example: compute SignalDistortionRatio on random signals and plot the result."""
import torch

from torchmetrics.audio import SignalDistortionRatio

metric = SignalDistortionRatio()
# Feed one pair of 8000-sample random tensors (uniform in [0, 1)) to the metric;
# the values are placeholders, not real audio.
metric.update(torch.rand(8000), torch.rand(8000))
# Plot the metric; the call's return value is unpacked into a figure and axes.
fig_, ax_ = metric.plot()