"""Plot a ``MultilabelPrecisionAtFixedRecall`` metric computed on random data."""
from torch import rand, randint
from torchmetrics.classification import MultilabelPrecisionAtFixedRecall

# Three independent labels; for each label the metric reports the best
# precision attainable while keeping recall at or above ``min_recall``.
metric = MultilabelPrecisionAtFixedRecall(num_labels=3, min_recall=0.5)

# 20 samples: random scores in [0, 1) and random binary (0/1) targets,
# both of shape (20, 3) to match ``num_labels=3``.
metric.update(rand(20, 3), randint(2, (20, 3)))

# The returned plot only shows the maximum precision value by default
# (recall is the fixed constraint here, not the plotted quantity).
fig_, ax_ = metric.plot()