For anyone stumbling on this: the issue is fixable in PyTorch >= 1.10 via DDP's `_set_static_graph` API. To implement it in PyTorch Lightning, subclass the DDP plugin:
from pytorch_lightning.overrides.base import LightningDistributedModule
from pytorch_lightning.plugins import DDPPlugin

class CustomDDPPlugin(DDPPlugin):
    def configure_ddp(self):
        self.pre_configure_ddp()
        self._model = self._setup_model(LightningDistributedModule(self.model))
        self._register_ddp_hooks()
        self._model._set_static_graph()  # THIS IS THE MAGIC LINE
Then construct the Trainer as usual:
trainer = Trainer(gpus=4, strategy=CustomDDPPlugin())
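For reference, here is a minimal plain-PyTorch sketch of the underlying call the plugin makes. This is an assumption-laden illustration, not Lightning code: it runs a single-process gloo process group on CPU and assumes PyTorch >= 1.10, where DDP exposes the (private) `_set_static_graph` method.

```python
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Single-process process group for illustration only (CPU, gloo backend).
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

model = DDP(torch.nn.Linear(4, 2))
model._set_static_graph()  # the same call the custom plugin issues on its wrapped model

out = model(torch.randn(3, 4))
print(out.shape)

dist.destroy_process_group()
```

Calling `_set_static_graph()` tells DDP that the set of used and unused parameters will not change across iterations, which is what resolves the error in question.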