FSDP sharded checkpointing slower than any other method

Hi folks!

I’ve recently been looking into ways to speed up the checkpointing part of my large-scale training workloads. I was very happy to find out that the PyTorch Lightning FSDP implementation already supports a sharded checkpointing mechanism, which I expected would not only reduce storage utilization but also significantly reduce the save / load time, at the cost of a single all-gather operation. I picked a fairly large model (barely fitting on an H100) and checked the following scenarios:

  1. Plain PyTorch DDP
  2. Data-parallel-only FSDP (NO_SHARD strategy)
  3. Data-parallel-only FSDP with sharded checkpointing (NO_SHARD strategy with state_dict_type: "sharded")

I measured the time spent saving the checkpoint as the difference between the on_save_checkpoint() callback and the on_train_batch_start() callback of the following iteration (a minimal sketch of this setup is shown below).
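
Roughly, the callback-based timing and the sharded-checkpointing configuration looked like the sketch below (a minimal, simplified version assuming the Lightning 2.x FSDPStrategy with its sharding_strategy / state_dict_type arguments; the actual Trainer, model and data setup is omitted):

```python
import time

import lightning.pytorch as pl
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.strategies import FSDPStrategy


class CheckpointTimer(Callback):
    """Wall-clock time from on_save_checkpoint to the next on_train_batch_start."""

    def __init__(self):
        self._save_started = None
        self.save_times = []

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Fires while the checkpoint dict is being assembled, before it is written.
        self._save_started = time.monotonic()

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # The first batch start after a save marks the end of the save window.
        if self._save_started is not None:
            self.save_times.append(time.monotonic() - self._save_started)
            self._save_started = None


# Scenario 3: data-parallel-only FSDP with sharded checkpointing.
strategy = FSDPStrategy(
    sharding_strategy="NO_SHARD",  # data-parallel only, no parameter sharding
    state_dict_type="sharded",     # write distributed (.distcp) checkpoint files
)

trainer = pl.Trainer(
    strategy=strategy,
    callbacks=[CheckpointTimer()],
    # devices / num_nodes / checkpointing frequency omitted
)
```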

I was pretty surprised to find that, regardless of the scale, the experiments I ran showed DDP yielding the best performance, while sharded checkpointing was significantly slower than any other method:

| Method | # nodes | Total fit time | save_ckpt_time mean [s] | save_ckpt_time std [s] | save_ckpt_time min [s] | save_ckpt_time max [s] |
| --- | --- | --- | --- | --- | --- | --- |
| fsdp | 1 | 540.49 | 38.53 | 1.58 | 36.98 | 42.83 |
| fsdp | 2 | 674.81 | 49.50 | 3.91 | 43.77 | 56.13 |
| fsdp | 4 | 557.34 | 37.98 | 1.69 | 34.82 | 40.45 |
| fsdp | 8 | 568.06 | 37.99 | 1.39 | 35.44 | 40.35 |
| fsdp_sharded | 1 | 633.38 | 47.76 | 2.46 | 43.19 | 50.74 |
| fsdp_sharded | 2 | 687.81 | 51.70 | 4.46 | 46.70 | 60.57 |
| fsdp_sharded | 4 | 627.69 | 46.65 | 3.25 | 43.09 | 50.86 |
| fsdp_sharded | 8 | 711.28 | 50.79 | 3.55 | 44.06 | 55.98 |
| ddp | 1 | 564.02 | 39.90 | 2.98 | 37.96 | 48.08 |
| ddp | 2 | 535.85 | 34.36 | 6.25 | 31.36 | 51.99 |
| ddp | 4 | NCCL OOM | - | - | - | - |
| ddp | 8 | NCCL OOM | - | - | - | - |

Did I measure the time spent saving the checkpoint correctly, and did I use sharded checkpointing correctly? If yes, how come sharded checkpointing yielded the worst performance even on 8 nodes (i.e., while saving only 1/64th of the model on each device)?

Thanks!

Replying to my own post, but the root cause of the poor performance seems to be that, due to the NO_SHARD strategy, FSDP saved one “big” shard and multiple empty ones:

```
$ du -hs *
23G   __0_0.distcp
0     __10_0.distcp
0     __11_0.distcp
0     __12_0.distcp
...
0     __9_0.distcp
7.6G  meta.pt
```
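
Since NO_SHARD keeps a full, unsharded copy of the parameters on every rank, the "sharded" state dict degenerates into a single shard, so only one rank actually writes data. If the intent is to spread the checkpoint I/O across ranks, the parameters themselves have to be sharded; a minimal sketch of that configuration (assuming the Lightning 2.x FSDPStrategy arguments, with FULL_SHARD picked purely as an illustration):

```python
from lightning.pytorch.strategies import FSDPStrategy

# With FULL_SHARD each rank owns roughly 1/world_size of the parameters,
# so state_dict_type="sharded" should produce a non-empty .distcp file per
# rank instead of one big file plus empty placeholders.
strategy = FSDPStrategy(
    sharding_strategy="FULL_SHARD",
    state_dict_type="sharded",
)
```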