FullyShardedDataParallel no memory decrease

Hi there,

I’m trying to perform fully-sharded data parallel training across 4 A100s to train a 3D U-Net for segmentation of medical images. I’m hoping to decrease the GPU memory footprint during training, so that I can train on larger (higher-resolution) CT scans.


  • pytorch-lightning v1.8.3
  • torch v1.9.0+cu111

I’ve enabled FSDP using the strategy='fsdp' flag with precision=16. I’ve also used the auto_wrap function in my LightningModule (with a custom policy that lowers the minimum number of params required to wrap a module, as my whole model has only 90M params). This seems to be working, as the model size per GPU is shown as 22.6M. Not sure why the estimated model size is only 45.146M, though…

However, I’m not seeing any decrease in GPU memory footprint between training on a single GPU and training on 4 GPUs. I couldn’t attach the screenshots because I’m a new forum user.

Any suggestions?



The model summary has not yet been updated to be compatible with sharded models like FSDP, so the sizes shown there can be misleading. There will be a separate “params per device” column, similar to the summary shown when using DeepSpeed.

Can you show us how you have wrapped and applied the policy?


Sure thing. In my LightningModule I define self.__network, an nn.Module that is used in the forward and training_step methods. I’ve wrapped this module using the auto_wrap function from fairscale.

The custom policy’s min_num_params was set to 0 so that my model would be wrapped regardless of its number of parameters.

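import torch.nn as nn
from fairscale.nn import auto_wrap

# Excerpt from my LightningModule (UNet3D and pretrained_model are defined elsewhere)
def __init__(self, ...):
    super().__init__()
    self.__network = UNet3D(pretrained_model=pretrained_model)

def configure_sharded_model(self):
    # Custom policy: wrap any module with >= min_num_params unwrapped params;
    # with the default of 0, every module is wrapped regardless of size
    def auto_wrap_policy(
        module: nn.Module,
        recurse: bool,
        unwrapped_params: int,
        module_is_root: bool,
        min_num_params: int = 0) -> bool:
        return unwrapped_params >= min_num_params

    self.__network = auto_wrap(self.__network, auto_wrap_policy=auto_wrap_policy)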


Unless you need the fairscale version for something specific, I suggest switching to the native FSDP strategy from PyTorch.


I also think you need to set a minimum number of parameters for a module to be wrapped (sharded), like so:

return unwrapped_params >= 10000  # wrap modules with at least 10k unwrapped params

or, with native FSDP, like this:

import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10000)
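To see how the threshold changes what gets wrapped, here is a simplified pure-Python sketch of size-based auto-wrapping: children are visited first, and a module is wrapped if the parameters not already inside a wrapped child meet min_num_params. ToyModule, recursive_wrap, wrapped_names, build_toy_unet and all parameter counts are hypothetical stand-ins, not the fairscale or torch API, and details such as force-leaf or excluded modules are left out.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for an nn.Module: its own params plus child modules."""
    name: str
    own_params: int = 0
    children: List["ToyModule"] = field(default_factory=list)

    def total_params(self) -> int:
        return self.own_params + sum(c.total_params() for c in self.children)


def recursive_wrap(module: ToyModule, min_num_params: int,
                   wrapped: List[str], is_root: bool = True) -> int:
    """Simplified size-based auto-wrapping; returns the params now inside wrappers."""
    # Children are visited first, so inner modules are wrapped before outer ones
    already_wrapped = sum(
        recursive_wrap(child, min_num_params, wrapped, is_root=False)
        for child in module.children
    )
    # Only params not already claimed by a wrapped child count towards the threshold
    remaining = module.total_params() - already_wrapped
    # The root is left to the outermost wrapper, so it is never auto-wrapped here
    if not is_root and remaining >= min_num_params:
        wrapped.append(module.name)
        return module.total_params()
    return already_wrapped


def wrapped_names(model: ToyModule, min_num_params: int) -> List[str]:
    wrapped: List[str] = []
    recursive_wrap(model, min_num_params, wrapped)
    return wrapped


def build_toy_unet() -> ToyModule:
    # Hypothetical parameter counts that add up to 90M, roughly the size in the post
    encoder = ToyModule("encoder", children=[
        ToyModule("enc_block1", 10_000_000), ToyModule("enc_block2", 30_000_000)])
    decoder = ToyModule("decoder", children=[
        ToyModule("dec_block1", 30_000_000), ToyModule("dec_block2", 19_990_000)])
    act = ToyModule("act")            # e.g. an activation layer with no params
    head = ToyModule("head", 10_000)  # small output layer
    return ToyModule("unet", children=[encoder, decoder, act, head])


if __name__ == "__main__":
    model = build_toy_unet()
    for threshold in (0, 10_000, 100_000_000):
        print(f"min_num_params={threshold:,}: {wrapped_names(model, threshold)}")
```

With min_num_params=0, every non-root module gets its own wrapper, including the parameterless act layer and the encoder/decoder containers whose params already live in wrapped children. With 10_000, only the blocks and the head are wrapped. With 100_000_000 (the default min_num_params of both fairscale’s default_auto_wrap_policy and torch’s size_based_auto_wrap_policy), nothing in a 90M-parameter model reaches the threshold.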

Thanks for the suggestion.

I needed to bump the torch version (now v1.13.0) to access native FSDP. I’m now seeing the following error during the forward pass:

RuntimeError: Input type (torch.cuda.HalfTensor) and weight type (torch.FloatTensor) should be the same

It seems like the model weights haven’t been moved to the GPU or converted to 16-bit precision as expected by the precision=16 flag. Any thoughts on this?


This is likely because you need to change your configure_optimizers from

def configure_optimizers(self):
    return YourOptimizer(self.parameters(), ...)

to

def configure_optimizers(self):
    # Use the parameters of the FSDP-wrapped model
    return YourOptimizer(self.trainer.model.parameters(), ...)

This is a quirk that is currently necessary in Lightning; as FSDP support matures, it won’t be needed anymore.

If this doesn’t solve your issue, please feel free to post a GitHub bug report and we can take a closer look.

Thanks, but I’m already configuring the optimiser like that:

from torch.optim import SGD

def configure_optimizers(self):
    return SGD(self.trainer.model.parameters(), lr=1e-3, momentum=0.9)

I’ll post that GitHub bug report. Thanks for your help!


I figured out the tensor datatype issue: I was still wrapping my module using the auto_wrap function from fairscale. I switched to the FSDP class from PyTorch and passed it the auto_wrap_policy you suggested, and I also set device_id, which was required to fix another error.

The network is now training with FSDP, but there is still no memory decrease when moving to multiple GPUs:

1 GPU:

  • torch.cuda.memory_allocated(): ~ 45GB
  • nvidia-smi: ~75GB

2 GPUs:

  • torch.cuda.memory_allocated(): ~ 45GB per GPU
  • nvidia-smi: ~76GB per GPU
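To put the ~45 GB in perspective, here is a back-of-the-envelope estimate of how much memory FSDP can shard for a ~90M-parameter model. It is a rough sketch assuming fp32 parameters, fp32 gradients and one SGD momentum buffer per parameter, with the whole model treated as a single flat parameter; it ignores activations, temporarily all-gathered full parameters, fp16 copies and allocator overhead. The helper names are hypothetical.

```python
GB = 1e9  # decimal gigabytes, for a rough comparison with the numbers above


def params_per_rank(num_params: int, world_size: int) -> int:
    # FSDP pads each flat parameter so it splits evenly: ceil(num_params / world_size)
    return -(-num_params // world_size)


def sharded_state_gb(num_params: int, world_size: int,
                     bytes_per_param: int = 4, copies: int = 3) -> float:
    """Approximate per-GPU memory for the state that FSDP shards.

    copies=3 assumes fp32 params + fp32 grads + one SGD momentum buffer.
    Activations, all-gathered full params, fp16 copies and allocator
    overhead are deliberately ignored.
    """
    return params_per_rank(num_params, world_size) * bytes_per_param * copies / GB


if __name__ == "__main__":
    n = 90_000_000  # approximate parameter count from the original post
    for ws in (1, 2, 4):
        print(f"{ws} GPU(s): {params_per_rank(n, ws) / 1e6:.1f}M params, "
              f"~{sharded_state_gb(n, ws):.2f} GB of sharded state per GPU")
```

Under these assumptions the shardable state is about 1.08 GB on 1 GPU and 0.54 GB per GPU on 2 GPUs (22.5M params per GPU on 4 GPUs, close to the 22.6M in the model summary). A saving of roughly 0.5 GB would be hard to spot next to ~45 GB, which suggests most of the allocated memory is something FSDP does not shard, such as activations; this is an estimate, not a measurement.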

I also tried removing the configure_sharded_model method from my LightningModule and training again, but the model wasn’t wrapped at all (perhaps because it has fewer than 100M params), and I got similar results for the 1- and 2-GPU runs.

If you don’t have any ideas on this, should I go ahead and post the bug report?