# Minimal example: feed one prediction/target pair into MeanAveragePrecision
# and plot the resulting metric.
from torch import tensor
from torchmetrics.detection.mean_ap import MeanAveragePrecision

# One entry per image; this example has a single image with one predicted box.
# NOTE(review): coordinates are presumably (xmin, ymin, xmax, ymax), the
# metric's default box format — confirm against the TorchMetrics docs.
preds = [
    dict(
        boxes=tensor([[258.0, 41.0, 606.0, 285.0]]),
        scores=tensor([0.536]),
        labels=tensor([0]),
    )
]

# Ground truth for the same image: one box with the same label as the
# prediction, shifted 44 units to the left on both x coordinates.
target = [
    dict(
        boxes=tensor([[214.0, 41.0, 562.0, 285.0]]),
        labels=tensor([0]),
    )
]

metric = MeanAveragePrecision()
metric.update(preds, target)

# plot() is unpacked into two values, kept here as the figure and axes handles.
fig_, ax_ = metric.plot()