What is the correct way to restore dataloader state so training resumes from the correct batch after pre-emption or failure?

I am working on a shared cluster where my jobs can be pre-empted, and I was wondering what the best practices are for restoring dataloader state. If a run is killed mid-epoch, for example, and the model is restored from a checkpoint, the dataloader should resume from the same batch index where it left off.

I need to support:

  • Shuffling
  • Data parallelism (DDP and/or FSDP)
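To make the requirement concrete, here is a minimal sketch of the behavior I have in mind. It is plain Python, not a real implementation: the class name and fields are hypothetical, and it only illustrates the idea of a sampler that shuffles deterministically from `seed + epoch` (so every rank sees the same permutation), shards by rank, and exposes `state_dict()` / `load_state_dict()` so a restored run skips the batches already consumed:

```python
import random

class ResumableShardedSampler:
    """Hypothetical sketch: a shuffled, rank-sharded sampler whose
    mid-epoch position can be checkpointed and restored."""

    def __init__(self, dataset_len, rank, world_size, seed=0):
        self.dataset_len = dataset_len
        self.rank = rank
        self.world_size = world_size
        self.seed = seed
        self.epoch = 0   # which epoch's permutation to use
        self.index = 0   # position within this rank's shard

    def state_dict(self):
        # Saved alongside the model checkpoint.
        return {"epoch": self.epoch, "index": self.index}

    def load_state_dict(self, state):
        self.epoch = state["epoch"]
        self.index = state["index"]

    def __iter__(self):
        # Same seed + epoch on every rank -> identical permutation,
        # then each rank takes a strided shard of it.
        rng = random.Random(self.seed + self.epoch)
        order = list(range(self.dataset_len))
        rng.shuffle(order)
        shard = order[self.rank::self.world_size]
        # Resume mid-epoch by skipping already-consumed positions.
        for i in range(self.index, len(shard)):
            self.index = i + 1
            yield shard[i]
        # Epoch finished: advance and rewind for the next pass.
        self.epoch += 1
        self.index = 0
```

With this, "save sampler state with the checkpoint, load it on restart" gives the resumed run the exact tail of the interrupted epoch. Is this roughly the right pattern, or is there a supported way to do it (e.g. something stateful in the dataloader itself) that handles shuffling and DDP/FSDP sharding for me?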