# Example: compute a TotalVariation metric on one batch of random images and
# plot it with the metric's built-in ``plot`` method.
#
# The original text had every statement collapsed onto a single line, which
# is a SyntaxError; each statement is restored to its own line here.
import torch
from torchmetrics.image import TotalVariation

metric = TotalVariation()

# One batch of random data with shape (5, 3, 28, 28); torch.rand samples
# uniformly from [0, 1).
# NOTE(review): the dims are presumably (batch, channels, height, width) as
# TotalVariation expects image batches -- confirm against torchmetrics docs.
metric.update(torch.rand(5, 3, 28, 28))

# plot() is called with no value, so it presumably plots the metric computed
# from the state accumulated by update(); fig_ / ax_ are presumably the
# matplotlib Figure and Axes it returns -- verify against Metric.plot docs.
fig_, ax_ = metric.plot()