"""Plot a ``MulticlassPrecisionAtFixedRecall`` metric computed on random data.

Example script: builds a 3-class precision-at-fixed-recall metric, feeds it a
single batch of random predictions and targets, and renders the result with
the metric's built-in ``plot()`` helper.
"""

from torch import rand, randint
from torchmetrics.classification import MulticlassPrecisionAtFixedRecall

# min_recall=0.5: the metric reports, per class, the best precision among the
# decision thresholds whose recall is at least 0.5.
metric = MulticlassPrecisionAtFixedRecall(num_classes=3, min_recall=0.5)

# 20 samples: softmax over the last dim turns random scores of shape (20, 3)
# into per-class probabilities; targets are random class indices in [0, 3).
# The inputs are unseeded, so the plotted values differ from run to run.
metric.update(rand(20, 3).softmax(dim=-1), randint(3, (20,)))

# By default the returned plot shows only the (maximum) precision value per
# class, not the thresholds that compute() also returns.
fig_, ax_ = metric.plot()