Gradient checkpointing + DDP = NaN

For anyone stumbling on this: the issue is fixable in PyTorch >= 1.10 via DDP's _set_static_graph() call. Gradient checkpointing re-runs parts of the forward pass during backward, so DDP's per-parameter autograd hooks can fire more than once; marking the graph as static tells DDP to expect this instead of producing bad reductions. To implement it in PyTorch Lightning, one can subclass DDPPlugin:

# Imports assume the Lightning 1.5-era plugin API
from pytorch_lightning.overrides import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin

class CustomDDPPlugin(DDPPlugin):
    def configure_ddp(self):
        # Reproduce the stock DDPPlugin setup first
        self.pre_configure_ddp()
        self._model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()
        self._model._set_static_graph()  # THIS IS THE MAGIC LINE

Then construct the Trainer as usual:

trainer = Trainer(gpus=4, strategy=CustomDDPPlugin())
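
For context, here is a minimal self-contained sketch of the same fix in raw PyTorch, without Lightning. The CheckpointedNet module and the torchrun launch are assumptions for illustration, not part of the original fix:

import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.checkpoint import checkpoint

class CheckpointedNet(nn.Module):
    # Toy module whose block is gradient-checkpointed (hypothetical example)
    def __init__(self):
        super().__init__()
        self.block = nn.Sequential(nn.Linear(32, 32), nn.ReLU(), nn.Linear(32, 32))

    def forward(self, x):
        # checkpoint() re-runs self.block during backward, which is what
        # makes DDP's per-parameter hooks fire more than once
        return checkpoint(self.block, x)

def main():
    # Assumes a launch via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK
    dist.init_process_group("nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)

    model = DistributedDataParallel(CheckpointedNet().cuda(), device_ids=[local_rank])
    model._set_static_graph()  # same call the Lightning plugin above makes

    # Input must require grad so the reentrant checkpoint recomputes in backward
    x = torch.randn(8, 32, device="cuda", requires_grad=True)
    model(x).sum().backward()

if __name__ == "__main__":
    main()

From PyTorch 1.11 onward, the same behavior is exposed as the public static_graph=True constructor argument to DistributedDataParallel, so the private call is not needed there.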