"""Plot a single Scale-Invariant Signal-to-Noise Ratio (SI-SNR) value.

Example script: updates a torchmetrics ``ScaleInvariantSignalNoiseRatio``
metric with one pair of random signals and renders it via ``metric.plot()``.
"""
import torch
from torchmetrics.audio import ScaleInvariantSignalNoiseRatio

# Seed the RNG so the random inputs (and hence the plotted value) are the
# same on every run of this example.
torch.manual_seed(42)

metric = ScaleInvariantSignalNoiseRatio()
# Random stand-ins for (preds, target); each is a length-4 tensor here.
metric.update(torch.rand(4), torch.rand(4))
# ``plot()`` returns a (figure, axes) pair; the trailing underscores mark
# them as intentionally unused in this example.
fig_, ax_ = metric.plot()