Ran into something similar (but for the mAP metric) and solved it by keeping the metric on the CPU and providing a custom dist_sync_fn
that moves the tensor to the CUDA device, performs the synchronization, and then moves the result back to the CPU:
import typing as T

import torch
from torchmetrics.utilities.distributed import gather_all_tensors

def all_gather_on_cuda(tensor: torch.Tensor, *args: T.Any, **kwargs: T.Any) -> T.List[torch.Tensor]:
    # Gather on the CUDA device, then move the gathered tensors back to the original (CPU) device.
    original_device = tensor.device
    return [
        _tensor.to(original_device)
        for _tensor in gather_all_tensors(tensor.to("cuda"), *args, **kwargs)
    ]
metric.dist_sync_fn = all_gather_on_cuda  # alternatively, pass this as a keyword argument to the metric's constructor, as sketched below
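
For completeness, the constructor route looks roughly like this (a minimal sketch assuming torchmetrics' MeanAveragePrecision; substitute whichever metric you are actually using):

from torchmetrics.detection import MeanAveragePrecision

# Keep the metric's state on the CPU; only the cross-process synchronization
# round-trips through CUDA via the custom dist_sync_fn defined above.
metric = MeanAveragePrecision(dist_sync_fn=all_gather_on_cuda)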