I have a model that uses gradient checkpointing and DDP. It works fine when I train it on a single GPU, and it also works fine if I turn off checkpointing. However, with multiple GPUs the loss initially looks innocent, but then suddenly becomes NaN:
           checkpointing    no checkpointing
gpus = 1   works            works
gpus = 4   fails            works
The only part of the model that uses checkpointing is:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

class MergeLayer(nn.Module):
    ...

    def apply_forward(self, inputs):
        x = torch.cat(inputs, 1)
        assert x.size(1) == self.in_channels
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = F.leaky_relu(x)
        assert x.size(1) == self.out_channels
        return x

    def _apply_forward_splat(self, *inputs):
        # checkpointing does not like to consume a list
        return self.apply_forward(inputs)

    def forward(self, inputs):
        assert total_channels(inputs) == self.in_channels
        if self.save_memory and any_requires_grad(inputs):
            x = checkpoint(self._apply_forward_splat, *inputs)
        else:
            x = self.apply_forward(inputs)
        assert x.size(1) == self.out_channels
        return x
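(The helpers total_channels and any_requires_grad are not part of the excerpt above; a minimal sketch of what they presumably do, inferred from how they are called and not taken from the actual code:)

def total_channels(inputs):
    # sum of the channel dimensions of all tensors in the list
    return sum(t.size(1) for t in inputs)

def any_requires_grad(inputs):
    # checkpointing is only worthwhile if at least one input needs grad
    return any(t.requires_grad for t in inputs)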
Hi, I am quite suspicious of what the checkpoint(...) call does here. Would you mind sharing a full example to reproduce? Eventually, maybe open an issue on PL and link it here…
I’ve observed training freezing when using DDP, gradient checkpointing and SyncBatchNorm together. By following this solution: Training gets stuck when using SyncBN · Issue #105 · NVIDIA/apex · GitHub, the training no longer freezes, but the loss becomes NaN after the first iteration (sometimes after several iterations).
I am wondering whether you have solved this issue, or whether there is anything interesting you’ve discovered? Thanks!
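(For context, the usual way to combine those three pieces in plain PyTorch looks roughly like the sketch below; it assumes the process group is already initialized, and build_model / local_rank are placeholders, not names from the posts above:)

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

# assumes torch.distributed.init_process_group(...) has already been called
model = build_model().cuda(local_rank)                  # build_model / local_rank are placeholders
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # swap BatchNorm layers for SyncBatchNorm
model = DDP(model, device_ids=[local_rank])             # gradient checkpointing happens inside the model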
For anyone stumbling on this, the issue is fixable in PyTorch >= 1.10 with the _set_static_graph API call. To implement it in PyTorch Lightning, one can do:
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin

class CustomDDPPlugin(DDPPlugin):
    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()
        self._model._set_static_graph()  # THIS IS THE MAGIC LINE
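To use it, hand the plugin to the Trainer. The exact keyword depends on the Lightning 1.x version you are on (strategy= in 1.5/1.6, plugins= earlier), so take this as a sketch rather than the exact call:

import pytorch_lightning as pl

trainer = pl.Trainer(gpus=4, strategy=CustomDDPPlugin())
trainer.fit(model)  # model is your LightningModule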
Thank you @Daniel_Murnane for the solution, it saved me! I am doing activation checkpointing (aka gradient checkpointing). Following on, here is a quick patch for Lightning 2.3.3:
import lightning as lit
from lightning.pytorch.strategies import DDPStrategy

class CustomDDPStrategy(DDPStrategy):
    def configure_ddp(self) -> None:
        assert isinstance(self.model, lit.LightningModule)
        self.model = self._setup_model(self.model)
        self._register_ddp_hooks()
        self._model._set_static_graph()  # THIS IS THE MAGIC LINE
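Same usage pattern in 2.x, just passed as strategy= (the accelerator/devices values below are illustrative, not from the post above):

trainer = lit.Trainer(accelerator="gpu", devices=4, strategy=CustomDDPStrategy())
trainer.fit(model)  # model is your LightningModule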