import torch
from torchmetrics.retrieval import RetrievalRecall

# Build the metric, evaluate it on ten random batches, then plot the results.
metric = RetrievalRecall()

# Each call passes 10 random scores, 10 random 0/1 targets and 10 random 0/1
# entries for `indexes`; the value returned by every call is collected in order.
values = [
    metric(
        torch.rand(10),
        torch.randint(2, (10,)),
        indexes=torch.randint(2, (10,)),
    )
    for _ in range(10)
]

# `plot` is handed the full list of collected values and its result is
# unpacked into a (fig, ax) pair.
fig, ax = metric.plot(values)