Gradient checkpointing + DDP = NaN

I will try to reduce this to a minimal example and then post it.