How can we skip a step with NaN loss in the training_step when using Distributed Data Parallel (DDP) across multiple machines and multiple GPUs? Returning None in the training_step can skip the step, but it doesn’t work for multi-machine multi-GPU scenarios. What are the possible solutions?
There is no efficient built-in solution for this. The workaround is to gather the loss from all processes and skip the step on every rank if any of them produced a NaN:

```python
losses = self.all_gather(loss)  # gather the loss value from every process
if torch.isnan(losses).any():
    return None  # skip this training step on all ranks
```
Gathering from all processes is necessary so that the ranks don't fall out of sync when only one of them hits a NaN: every rank must reach the same skip-or-continue decision for the same step.
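To illustrate why the gather matters, here is a torch-free sketch of the decision logic (`should_skip` and the loss values are made up for illustration): after an all-gather, every rank sees the same list of losses, so they all make the same decision even though only one rank produced the NaN.

```python
import math

def should_skip(gathered_losses):
    # After all_gather, every rank holds the same list,
    # so this returns the same answer on every rank.
    return any(math.isnan(l) for l in gathered_losses)

# Simulate 4 ranks where rank 2 produced a NaN loss.
gathered = [0.5, 0.7, float("nan"), 0.6]

# Each rank evaluates the same gathered list independently.
decisions = [should_skip(gathered) for _ in range(4)]
print(decisions)  # all ranks agree to skip
```

Without the gather, only rank 2 would return `None`, and the other ranks would block waiting for its gradients in the next backward pass.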
However, the best solution is to fix the cause of the NaN in the network so that it doesn't occur in the first place. For debugging, you can enable `Trainer(detect_anomaly=True)`, which may help you locate the source of the instability.