Hi folks!
I’ve recently been looking into ways of speeding up the checkpointing part of my large-scale training workloads. I was very happy to find out that the PyTorch Lightning FSDP implementation already supports a sharded checkpointing mechanism, which I expected would not only reduce storage utilization, but also significantly reduce save/load time at the cost of a single all-gather operation. I decided to pick a fairly large model (barely fitting on an H100) and check the following scenarios (a rough config sketch follows the list):
- Plain PyTorch DDP
- Data-parallel-only FSDP (`NO_SHARD` strategy)
- Data-parallel-only FSDP with sharded checkpointing (`NO_SHARD` strategy with `state_dict_type: "sharded"`)
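
For reference, the three setups correspond roughly to the following strategies (a simplified Python sketch; the dictionary and the exact `devices` / `num_nodes` values are just for illustration):

```python
import lightning as L
from lightning.pytorch.strategies import DDPStrategy, FSDPStrategy

strategies = {
    # Plain PyTorch DDP
    "ddp": DDPStrategy(),
    # Data-parallel-only FSDP: every rank keeps a full replica,
    # rank 0 writes a single consolidated checkpoint
    "fsdp": FSDPStrategy(sharding_strategy="NO_SHARD", state_dict_type="full"),
    # Same FSDP setup, but every rank writes its own shard of the checkpoint
    "fsdp_sharded": FSDPStrategy(sharding_strategy="NO_SHARD", state_dict_type="sharded"),
}

trainer = L.Trainer(
    strategy=strategies["fsdp_sharded"],
    accelerator="gpu",
    devices=8,    # GPUs per node (placeholder)
    num_nodes=8,  # placeholder
)
```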
I measured the time spent saving the checkpoint as the difference between the `on_save_checkpoint()` callback and the `on_train_batch_start()` callback of the following iteration.
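
The measurement itself boils down to a callback along these lines (a minimal sketch; the class name and bookkeeping are mine, not taken from my actual training code):

```python
import time

from lightning.pytorch.callbacks import Callback


class CheckpointTimer(Callback):
    """Records the wall-clock time from on_save_checkpoint() until the next batch starts."""

    def __init__(self):
        self._save_started = None
        self.save_ckpt_times = []

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Called right before the checkpoint is written out
        self._save_started = time.monotonic()

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        # First batch after a save: everything in between is checkpointing overhead
        if self._save_started is not None:
            self.save_ckpt_times.append(time.monotonic() - self._save_started)
            self._save_started = None
```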
I was pretty surprised to find that, regardless of the scale, DDP yielded the best performance in my experiments, and sharded checkpointing was significantly slower than any other method:
| Method | # nodes | Total fit time | save_ckpt_time mean [s] | save_ckpt_time std [s] | save_ckpt_time min [s] | save_ckpt_time max [s] |
|---|---|---|---|---|---|---|
| fsdp | 1 | 540.49 | 38.53 | 1.58 | 36.98 | 42.83 |
| fsdp | 2 | 674.81 | 49.50 | 3.91 | 43.77 | 56.13 |
| fsdp | 4 | 557.34 | 37.98 | 1.69 | 34.82 | 40.45 |
| fsdp | 8 | 568.06 | 37.99 | 1.39 | 35.44 | 40.35 |
| fsdp_sharded | 1 | 633.38 | 47.76 | 2.46 | 43.19 | 50.74 |
| fsdp_sharded | 2 | 687.81 | 51.70 | 4.46 | 46.70 | 60.57 |
| fsdp_sharded | 4 | 627.69 | 46.65 | 3.25 | 43.09 | 50.86 |
| fsdp_sharded | 8 | 711.28 | 50.79 | 3.55 | 44.06 | 55.98 |
| ddp | 1 | 564.02 | 39.90 | 2.98 | 37.96 | 48.08 |
| ddp | 2 | 535.85 | 34.36 | 6.25 | 31.36 | 51.99 |
| ddp | 4 | NCCL OOM | - | - | - | - |
| ddp | 8 | NCCL OOM | - | - | - | - |
Did I measure the time spent saving the checkpoint and use sharded checkpointing correctly? If yes, how come sharded checkpointing yielded the worst performance even on 8 nodes (i.e., while saving only 1/64th of the model on each device)?
Thanks!