FullyShardedDataParallel no memory decrease

Hi!

Sure thing! In my LightningModule I've defined self.__network, which is an nn.Module used in the forward and training_step methods. I've wrapped this module using the auto_wrap function from fairscale.

In the custom policy, min_num_params was set to 0 so that the model is wrapped regardless of its number of parameters.

from torch import nn

from fairscale.nn import auto_wrap

def __init__(self, ...):
    ...
    self.__network = UNet3D(pretrained_model=pretrained_model)
    ...

def configure_sharded_model(self):
    def auto_wrap_policy(
        module: nn.Module,
        recurse: bool,
        unwrapped_params: int,
        module_is_root: bool,
        min_num_params: int = 0,
    ) -> bool:
        # With min_num_params=0 this returns True for every module,
        # so every submodule is eligible for wrapping.
        return unwrapped_params >= min_num_params

    self.__network = auto_wrap(self.__network, auto_wrap_policy=auto_wrap_policy)
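In case it's useful, this is roughly the sanity check I run afterwards to confirm that submodules actually ended up sharded. It's just a minimal sketch inside the same LightningModule: the on_fit_start hook and the print are only for illustration, and I'm assuming fairscale's FullyShardedDataParallel class here.

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

def on_fit_start(self):
    # Count the submodules that auto_wrap actually turned into FSDP
    # instances; if this is 0, no sharding happened and no memory
    # decrease should be expected.
    wrapped = [name for name, module in self.__network.named_modules()
               if isinstance(module, FSDP)]
    print(f"FSDP-wrapped submodules: {len(wrapped)}")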

Thanks,
Brett