Hi!
Sure thing. In my LightningModule I define `self.__network`, an `nn.Module` that is used in the `forward` and `training_step` methods. I've wrapped this module with fairscale's `auto_wrap` function, using a custom policy whose `min_num_params` is set to 0 so that the model is wrapped regardless of its number of parameters.
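```python
from torch import nn
from fairscale.nn import auto_wrap

# Methods of my LightningModule (class definition omitted)
def __init__(self, ...):
    ...
    self.__network = UNet3D(pretrained_model=pretrained_model)
    ...

def configure_sharded_model(self):
    # Custom policy: returns True whenever at least min_num_params
    # parameters are still unwrapped; with 0, every module qualifies.
    def auto_wrap_policy(
        module: nn.Module,
        recurse: bool,
        unwrapped_params: int,
        module_is_root: bool,
        min_num_params: int = 0,
    ) -> bool:
        return unwrapped_params >= min_num_params

    self.__network = auto_wrap(self.__network, auto_wrap_policy=auto_wrap_policy)
```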
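To illustrate what a threshold of 0 means, here is a rough, hypothetical model of recursive auto-wrapping in plain Python (no torch or fairscale needed). It assumes the wrapper asks the policy once with `recurse=True` before visiting a module's children, and once with `recurse=False` for the parameters the children did not absorb. That is roughly how fairscale's `auto_wrap` consults the policy, but this is a simplified sketch, not fairscale's code. `ToyModule`, `build_toy_unet` and the parameter counts are made up for illustration.

```python
from dataclasses import dataclass, field
from typing import Callable, List


@dataclass
class ToyModule:
    """Hypothetical stand-in for an nn.Module: a name, its own params, and children."""
    name: str
    own_params: int
    children: List["ToyModule"] = field(default_factory=list)

    def num_params(self) -> int:
        # Own params plus all descendants, like summing p.numel() over module.parameters().
        return self.own_params + sum(c.num_params() for c in self.children)


def make_policy(min_num_params: int) -> Callable[..., bool]:
    # Same rule as the policy in the post: say yes whenever enough params are unwrapped.
    def policy(module, recurse, unwrapped_params, module_is_root):
        return unwrapped_params >= min_num_params
    return policy


def recursive_wrap(module: ToyModule, policy, module_is_root: bool, wrapped: List[str]) -> int:
    """Append wrapped module names to `wrapped`; return how many params ended up wrapped."""
    total = module.num_params()
    # First question: is it worth descending into this module at all?
    if not policy(module, recurse=True, unwrapped_params=total, module_is_root=module_is_root):
        return 0
    wrapped_in_children = sum(
        recursive_wrap(child, policy, False, wrapped) for child in module.children
    )
    # Second question: should the params the children did not absorb get their own wrapper?
    remainder = total - wrapped_in_children
    if policy(module, recurse=False, unwrapped_params=remainder, module_is_root=module_is_root):
        wrapped.append(module.name)
        return total
    return wrapped_in_children


def build_toy_unet() -> ToyModule:
    # Hypothetical module tree; names and sizes are invented.
    encoder = ToyModule("encoder", 0, [ToyModule("conv1", 500), ToyModule("relu", 0)])
    decoder = ToyModule("decoder", 800)
    return ToyModule("unet", 0, [encoder, decoder])


def wrapped_modules(root: ToyModule, min_num_params: int) -> List[str]:
    wrapped: List[str] = []
    recursive_wrap(root, make_policy(min_num_params), True, wrapped)
    return wrapped


print(wrapped_modules(build_toy_unet(), 0))    # ['conv1', 'relu', 'encoder', 'decoder', 'unet']
print(wrapped_modules(build_toy_unet(), 600))  # ['decoder']
```

In this toy model, a threshold of 0 wraps every module, including the parameter-free `relu`, while a larger threshold only wraps subtrees big enough to pass the check.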
Thanks,
Brett