Gradient checkpointing + DDP = NaN

Hi, I am quite suspicious of what the checkpoint(...) call does. Would you mind sharing a full example to reproduce the issue? You could also open an issue on PL and link it here…