Compute Precision Recall Curve without OOM

I ran into something similar (but with the mAP metric) and solved it by keeping the metric on the CPU and providing a custom dist_sync_fn. The NCCL backend can only communicate CUDA tensors, so this function moves the tensor to the CUDA device, performs the synchronization there, and then moves the gathered results back to the CPU:

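```python
import typing as T

import torch
from torchmetrics.utilities.distributed import gather_all_tensors

def all_gather_on_cuda(tensor: torch.Tensor, *args: T.Any, **kwargs: T.Any) -> T.List[torch.Tensor]:
    original_device = tensor.device  # where the metric state lives (the CPU here)
    # Gather on the GPU, then move every gathered tensor back to the original device.
    return [
        _tensor.to(original_device)
        for _tensor in gather_all_tensors(tensor.to("cuda"), *args, **kwargs)
    ]

metric.dist_sync_fn = all_gather_on_cuda  # you could alternatively pass this as a keyword argument to the metric's constructor.
```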
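If you want the communication device to be configurable (for example an explicit "cuda:1" for a given rank) or want to reuse the same trick with a different gather function, the wrapper can be generalized into a small factory. This is only a sketch: make_device_roundtrip_sync is a hypothetical helper, not part of torchmetrics, and it assumes nothing beyond the tensors having a .device attribute and a .to() method, as torch.Tensor does. The default "cuda" string refers to the current CUDA device, so it assumes each process has already selected its own GPU (e.g. via torch.cuda.set_device).

```python
from typing import Any, Callable, List, Sequence


def make_device_roundtrip_sync(
    gather_fn: Callable[..., Sequence[Any]],
    comm_device: Any = "cuda",
) -> Callable[..., List[Any]]:
    """Hypothetical helper: wrap gather_fn so communication runs on comm_device.

    The input tensor is copied to comm_device before gathering, and every
    gathered tensor is copied back to the device the input originally lived on.
    Extra positional and keyword arguments (e.g. a process group) are forwarded.
    """

    def sync_fn(tensor: Any, *args: Any, **kwargs: Any) -> List[Any]:
        original_device = tensor.device  # e.g. the CPU, where the metric state is kept
        gathered = gather_fn(tensor.to(comm_device), *args, **kwargs)
        return [t.to(original_device) for t in gathered]

    return sync_fn
```

With torchmetrics you would then pass dist_sync_fn=make_device_roundtrip_sync(gather_all_tensors) when constructing the metric, so the metric itself stays on the CPU and only the gather runs on the GPU.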