"""Plot a binary ROC curve for randomly generated predictions.

Random two-class logits are turned into per-row probabilities with softmax.
The probability in column 1 is used as the positive-class score. Those scores
and random 0/1 labels are fed to torchmetrics' ``BinaryROC``, and the
accumulated curve is plotted.
"""
from torch import manual_seed, randint, randn
import torch.nn.functional as F
from torchmetrics.classification import BinaryROC

# Seed torch's global RNG so the random data, and therefore the plotted
# curve, are identical on every run.
manual_seed(42)

# 20 rows of 2-column logits -> softmax over dim=1 gives per-row probabilities.
preds = F.softmax(randn(20, 2), dim=1)
# 20 integer labels drawn from {0, 1} (randint's upper bound is exclusive).
target = randint(2, (20,))

metric = BinaryROC()
# Column 1 holds the positive-class probability, which is the score BinaryROC ranks.
metric.update(preds[:, 1], target)
# plot() is unpacked into a figure and an axes object; neither is used further here.
fig_, ax_ = metric.plot()