Hello, I am new to the community and have been enjoying PL so far! I am working with FSDP models and I am stumped as to what is going on.
My code looks like this:
from functools import partial
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy
from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy

# wrap any submodule with more than 1M parameters in its own FSDP unit
policy = partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
strategy = FSDPStrategy(auto_wrap_policy=policy, state_dict_type="sharded")
trainer = Trainer(accelerator="cuda", devices=2, strategy=strategy)
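For completeness, the error shows up when a checkpoint gets saved during the usual fit call; MyLitModule and MyDataModule below are just placeholders for my actual LightningModule and LightningDataModule:

model = MyLitModule()        # placeholder for my actual LightningModule
datamodule = MyDataModule()  # placeholder for my actual LightningDataModule
trainer.fit(model, datamodule=datamodule)  # the TypeError is raised when a checkpoint is saved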
The error message I get when the trainer saves the model during self.strategy.save_checkpoint is:
TypeError: cannot pickle 'module' object
which I am pretty sure is because my modules get wrapped like:
net.esm.encoder.layer.5._fsdp_wrapped_module
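(For reference, the wrapped names can be listed with a quick check like this, a rough sketch run from inside the LightningModule after FSDP has wrapped the model:)

# sanity check: list the FSDP-wrapped submodules by name
for name, _ in self.named_modules():
    if "_fsdp_wrapped_module" in name:
        print(name)  # e.g. net.esm.encoder.layer.5._fsdp_wrapped_module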
However, if I just use trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp"), things work fine.
So my questions are: 1. am I correct about the root cause of the pickling error? and 2. how do I address it? I would like to keep using my own FSDP auto-wrap policies.
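For context, this is roughly the kind of custom policy I would like to end up with; a minimal sketch using transformer_auto_wrap_policy, where EsmLayer just stands in for whatever transformer block class my model actually uses:

from functools import partial
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.esm.modeling_esm import EsmLayer  # stand-in for my model's block class
from lightning.pytorch.strategies import FSDPStrategy

# wrap each transformer block (rather than size-based units) in its own FSDP unit
policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={EsmLayer})
strategy = FSDPStrategy(auto_wrap_policy=policy, state_dict_type="sharded")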