"""Example: plot a multilabel recall-at-fixed-precision metric on random data."""
from torch import rand, randint
from torchmetrics.classification import MultilabelRecallAtFixedPrecision

# Three labels; recall is evaluated subject to precision >= 0.5.
metric = MultilabelRecallAtFixedPrecision(num_labels=3, min_precision=0.5)

# 20 samples x 3 labels: uniform scores in [0, 1) as predictions,
# and random 0/1 ground-truth targets of the same shape.
metric.update(rand(20, 3), randint(2, (20, 3)))

# NOTE(review): plotting presumably requires matplotlib to be installed.
fig_, ax_ = metric.plot()  # the returned plot only shows the maximum recall value by default