How can we skip a step with NaN loss in the training_step when using Distributed Data Parallel (DDP)?

How can we skip a step with NaN loss in the training_step when using Distributed Data Parallel (DDP) across multiple machines and multiple GPUs? Returning None in the training_step can skip the step, but it doesn’t work for multi-machine multi-GPU scenarios. What are the possible solutions?

There is no efficient solution for this. The inefficient way is to gather the loss from every process inside training_step and have all ranks skip together, roughly like this:

losses = self.all_gather(loss)      # gather the loss value from every process
if torch.isnan(losses).any():       # True if any rank produced a NaN
    return None                     # skip this training step on all ranks

Gathering from all processes is necessary so that the training loop doesn't fall out of sync when only one process hits a NaN and skips the step.
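
For context, here is a minimal, self-contained sketch of how this could look inside a LightningModule (the model, layer sizes, loss, and optimizer below are hypothetical, chosen only for illustration):

import torch
import pytorch_lightning as pl

class NanSkipModel(pl.LightningModule):   # hypothetical model for illustration
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        # Gather the scalar loss from every process; under DDP the result
        # has shape [world_size], on a single process it is just the scalar.
        losses = self.all_gather(loss)
        if torch.isnan(losses).any():
            # Every rank sees the same gathered tensor, so every rank
            # decides to skip together and the processes stay in sync.
            return None
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)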

However, the best solution is to fix the cause of the NaN in the network so that it doesn't happen in the first place. You can enable Trainer(detect_anomaly=True) for debugging; it may help you track down the source of the instability.
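
As a usage sketch (the NanSkipModel from above and train_loader are placeholders, not part of the original answer):

trainer = pl.Trainer(detect_anomaly=True)   # enables torch.autograd anomaly detection during training
trainer.fit(NanSkipModel(), train_loader)   # train_loader is a placeholder DataLoader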