FullyShardedDataParallel no memory decrease

I figured out the Tensor datatype issue: I was still wrapping my module with the auto_wrap function from fairscale. I moved to the FSDP class from PyTorch and passed the auto_wrap_policy you suggested to it, in addition to setting device_id, which was required to fix another error.
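
For reference, this is roughly how the wrapping looks now, a minimal sketch with wrap_with_fsdp as an illustrative helper (I call the equivalent from configure_sharded_model in my LightningModule):

```python
import functools

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy


def wrap_with_fsdp(module: nn.Module, min_params: int = 100_000_000) -> FSDP:
    """Wrap a module with PyTorch-native FSDP using a size-based auto-wrap policy."""
    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=min_params)
    return FSDP(
        module,
        auto_wrap_policy=policy,
        device_id=torch.cuda.current_device(),  # setting this fixed the other error
    )
```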

The network is now training with FSDP, but there is still no memory decrease when moving to multiple GPUs (measured as in the snippet after the numbers below):

1 GPU:

  • torch.cuda.memory_allocated(): ~45 GB
  • nvidia-smi: ~75 GB

2 GPUs:

  • torch.cuda.memory_allocated(): ~45 GB per GPU
  • nvidia-smi: ~76 GB per GPU
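
For completeness, the allocated numbers above come from a call like the following on each rank (a minimal sketch; the reserved figure is just an extra check, and the nvidia-smi numbers come straight from the command line):

```python
import torch
import torch.distributed as dist

# Per-rank memory readout used for the numbers above. nvidia-smi additionally
# counts the CUDA context and memory the caching allocator has reserved but
# not handed out, so it is expected to read higher than memory_allocated().
rank = dist.get_rank() if dist.is_initialized() else 0
allocated_gb = torch.cuda.memory_allocated() / 1024**3
reserved_gb = torch.cuda.memory_reserved() / 1024**3
print(f"rank {rank}: allocated={allocated_gb:.1f} GB, reserved={reserved_gb:.1f} GB")
```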

Additionally, I tried removing the configure_sharded_model method from my LightningModule and training again, but it didn't wrap the model at all (perhaps because the model has fewer than 100M parameters), and I got similar results for the 1- and 2-GPU runs.
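
As a sanity check on the "fewer than 100M parameters" guess, I counted parameters roughly like this (the model below is just a placeholder for my network):

```python
import torch.nn as nn

# Placeholder model purely for illustration; my real network goes here.
model = nn.Sequential(nn.Linear(1024, 1024), nn.Linear(1024, 10))

n_params = sum(p.numel() for p in model.parameters())
print(f"total params: {n_params / 1e6:.1f}M")
# size_based_auto_wrap_policy defaults to min_num_params=1e8 (100M), so
# submodules below that threshold are not individually wrapped.
```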

If you don’t have any further ideas on this, should I go ahead and post the bug report?

Thanks,
Brett