Saving FSDP model with custom FSDPStrategy results in TypeError: cannot pickle 'module' object

Hello, I am new to the community and have been enjoying PL so far! I am working with FSDP models and I am stumped as to what is going on.

My code looks like this:

from functools import partial

from lightning.pytorch import Trainer
from lightning.pytorch.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Size-based auto-wrap policy with a 1M-parameter threshold
policy = partial(size_based_auto_wrap_policy, min_num_params=1_000_000)
strategy = FSDPStrategy(auto_wrap_policy=policy, state_dict_type="sharded")
trainer = Trainer(accelerator="cuda", devices=2, strategy=strategy)

The error message I get when the trainer saves the model during self.strategy.save_checkpoint is:
TypeError: cannot pickle 'module' object

which I am pretty sure is because my modules get wrapped like:
net.esm.encoder.layer.5._fsdp_wrapped_module

However, if I just use trainer = Trainer(accelerator="cuda", devices=2, strategy="fsdp"), things work fine.

So my questions are: 1. Am I correct about the root cause of the pickling error? 2. How do I address it? I would like to use my own FSDP wrapping strategies.
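One way to narrow down question 1 is to check which entry of the saved object actually fails to pickle. In CPython 3.8+, pickle raises this message for Python module objects (types.ModuleType, e.g. something bound by an import statement), which is a different thing from a torch.nn.Module. The sketch below uses only the standard library. The find_unpicklable helper and the fake_checkpoint dict are hypothetical stand-ins for illustration, not Lightning's checkpoint format, and whether this is what happens in my case is an assumption.

```python
import pickle
import types


def find_unpicklable(obj, path="checkpoint"):
    """Return (path, error) pairs for leaves of nested dicts/lists that fail to pickle."""
    bad = []
    if isinstance(obj, dict):
        for key, value in obj.items():
            bad.extend(find_unpicklable(value, f"{path}[{key!r}]"))
    elif isinstance(obj, (list, tuple)):
        for i, value in enumerate(obj):
            bad.extend(find_unpicklable(value, f"{path}[{i}]"))
    else:
        try:
            pickle.dumps(obj)
        except Exception as err:
            bad.append((path, f"{type(err).__name__}: {err}"))
    return bad


# Hypothetical checkpoint-like dict: "backend" holds a Python module by mistake.
fake_checkpoint = {
    "epoch": 3,
    "hyper_parameters": {"lr": 1e-4, "backend": types},
}

problems = find_unpicklable(fake_checkpoint)
for where, why in problems:
    print(where, "->", why)
```

Running the same helper on the dict that is about to be saved should point at the offending key, which would show whether the FSDP wrapping or some other stored object is responsible.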

Looking into it a bit more, I've tried doing the same thing in Fabric, which seems to work just fine:

import torch
import lightning as L
from functools import partial
from lightning.fabric.strategies import FSDPStrategy
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy

policy = partial(size_based_auto_wrap_policy, min_num_params=1_000)
strategy = FSDPStrategy(auto_wrap_policy=policy, state_dict_type="sharded")

fabric = L.Fabric(devices=2, strategy=strategy)
fabric.launch()

model = torch.nn.Linear(100,100)
optimizer = torch.optim.SGD(model.parameters(), lr=0.0001)

model, optimizer = fabric.setup(model, optimizer)

state = {"model": model, "optimizer": optimizer}
fabric.save("saved_model/test.pt", state)

Looking at the code, I see that Fabric calls _unwrap_objects (in src/lightning/fabric/fabric.py of Lightning-AI/pytorch-lightning, at commit 6cfc590716cbf52e09033ae11ebee10864ef7589), a call I don't see in the Lightning strategy code. Perhaps this is the reason for the discrepancy?
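To make clear what I mean by unwrapping, here is a minimal plain-Python sketch of the idea: before saving, walk the state collection and replace every wrapper with the object it wraps. The Wrapper class and unwrap function are hypothetical names for illustration only. This is not Lightning's implementation, and I am assuming that _unwrap_objects does something along these lines.

```python
class Wrapper:
    """Hypothetical stand-in for a framework wrapper around a user object."""

    def __init__(self, inner):
        self.inner = inner


def unwrap(collection):
    """Recursively replace Wrapper instances in nested dicts/lists/tuples."""
    if isinstance(collection, Wrapper):
        return unwrap(collection.inner)
    if isinstance(collection, dict):
        return {key: unwrap(value) for key, value in collection.items()}
    if isinstance(collection, (list, tuple)):
        return type(collection)(unwrap(value) for value in collection)
    return collection


# Strings stand in for the real model and optimizer objects.
state = {
    "model": Wrapper("raw-model"),
    "optimizer": Wrapper("raw-optimizer"),
    "step": 10,
}
unwrapped = unwrap(state)
print(unwrapped)
```

If the Trainer path skips a step like this, whatever the wrapper holds on to would be passed to the serializer along with the model.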