# Example: compute and plot BinaryPrecisionAtFixedRecall on random data.
from torch import rand, randint
from torchmetrics.classification import BinaryPrecisionAtFixedRecall

# Only consider operating points whose recall is at least 0.5; the metric
# reports the best precision attainable under that recall constraint.
metric = BinaryPrecisionAtFixedRecall(min_recall=0.5)
# rand(10): ten scores in [0, 1); randint(2, (10,)): ten targets in {0, 1}.
metric.update(rand(10), randint(2, (10,)))
# plot() returns a (figure, axes) pair. The returned plot only shows the
# maximum precision value by default (not recall, which is the fixed constraint).
fig_, ax_ = metric.plot()