FullyShardedDataParallel no memory decrease

Unless you specifically need the fairscale version, I suggest switching to the native FSDP strategy from PyTorch:

strategy="fsdp_native"
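
For reference, here is a minimal sketch of how that flag is passed to the Trainer (the accelerator, devices, and precision values are placeholders for your own setup, and model stands in for your LightningModule):

import pytorch_lightning as pl

trainer = pl.Trainer(
    accelerator="gpu",
    devices=4,               # placeholder: use your own device count
    strategy="fsdp_native",  # native PyTorch FSDP instead of the fairscale-based "fsdp"
    precision=16,            # FSDP is usually combined with mixed precision
)
trainer.fit(model)           # model stands in for your LightningModule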

I also think you need to set a minimum number of parameters for auto-wrapping, so that FSDP actually shards your submodules instead of only wrapping the root module, like so:

def auto_wrap_policy(module, recurse, unwrapped_params: int) -> bool:
    return unwrapped_params >= 10000  # shard modules that have more than 10k params

or use the built-in size-based policy, which does the same thing:

import functools
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

auto_wrap_policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10000)
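
Whichever way you build it, the policy then has to be handed to FSDP. With Lightning you can do that through DDPFullyShardedNativeStrategy, which, as far as I know, forwards extra keyword arguments to the underlying FSDP wrapper (treat this as a sketch and check it against your Lightning version; again, model stands in for your LightningModule):

import functools
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPFullyShardedNativeStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# shard every submodule that has more than 10k parameters
policy = functools.partial(size_based_auto_wrap_policy, min_num_params=10000)

strategy = DDPFullyShardedNativeStrategy(auto_wrap_policy=policy)
trainer = pl.Trainer(accelerator="gpu", devices=4, strategy=strategy, precision=16)
trainer.fit(model)

If memory still doesn't drop, print the model once training has started; if you don't see FullyShardedDataParallel wrappers around your submodules, the policy isn't being applied.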