"""Minimal example: update a ``SignalNoiseRatio`` metric once and plot it."""

import torch
from torchmetrics.audio import SignalNoiseRatio

# Build the metric with its default configuration.
metric = SignalNoiseRatio()

# Feed a single update with two independent random 1-D tensors of length 4.
# NOTE(review): argument order is presumably (preds, target) -- confirm
# against the torchmetrics API.
metric.update(torch.rand(4), torch.rand(4))

# plot() is unpacked into two values; presumably a matplotlib figure/axes
# pair -- confirm against the torchmetrics plotting docs.
fig_, ax_ = metric.plot()