# Minimal example: accumulate one batch into the complex scale-invariant
# signal-to-noise ratio metric and plot the result.
import torch
from torchmetrics.audio import ComplexScaleInvariantSignalNoiseRatio

metric = ComplexScaleInvariantSignalNoiseRatio()

# Random tensors stand in for a prediction and its target.
# NOTE(review): the shape (1, 257, 100, 2) is presumably
# (batch, frequency, time, real/imag) -- confirm against the torchmetrics
# documentation for the installed version.
metric.update(torch.rand(1, 257, 100, 2), torch.rand(1, 257, 100, 2))

# plot() is unpacked into two values; going by the names these are the
# matplotlib figure and axes. Trailing underscores are kept from the original.
fig_, ax_ = metric.plot()