Saving model state dict with FSDP

I am trying to save the state_dict of a fine-tuned T5 model from Hugging Face. However, I later found that I could not load from the checkpoint properly, and when I inspected the state_dict stored in the checkpoint, it was actually empty.

import pytorch_lightning as pl
from transformers import T5ForConditionalGeneration

class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super(T5FineTuner, self).__init__()
        self.save_hyperparameters(hparams)
        
        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)

model = T5FineTuner(args)
early_stop_callback = pl.callbacks.EarlyStopping(
        monitor='val_loss',
        patience=3,
        verbose=False,
        mode='min')
checkpoint_callback = pl.callbacks.ModelCheckpoint(
        dirpath='model_checkpoints/',
        filename=args.model_name_or_path,  # your custom filename
        save_weights_only=True,
        monitor='val_loss',
        mode='min',
        save_top_k=1,
        )
callbacks = [LoggingCallback(logger=logger, task=args.task, file_path=output_num_file, args=args), early_stop_callback, checkpoint_callback]

trainer = pl.Trainer(strategy='fsdp', callbacks=callbacks, enable_checkpointing=True)

Is this due to using FSDP? How do I properly save the model when using FSDP? I tried saving it manually with torch.save(self.model.state_dict()) and it was empty as well.
However, when I am not using distributed training, it saves fine.

Hi @wj210

I’m posting my answer from GitHub here. The forum here is a better place to discuss this type of question, unless there is a bug or missing feature (then GitHub is better).


You can’t just save an FSDP model with a manual torch.save. You would have to add a fair amount of PyTorch boilerplate to get this right.
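For reference, the raw PyTorch approach looks roughly like the sketch below (assuming PyTorch 1.12+; fsdp_model stands for the FSDP-wrapped module and the file name is just an example):

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType, FullStateDictConfig

# Gather the full (unsharded) state dict, offloaded to CPU, on rank 0 only.
save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, save_policy):
    state_dict = fsdp_model.state_dict()

# With rank0_only=True, non-zero ranks receive an empty dict, which is
# likely what you observed when calling state_dict() directly.
if torch.distributed.get_rank() == 0:
    torch.save(state_dict, "fsdp_model.pt")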

I suggest you don’t do that, since Lightning can already do it for you. With Trainer(enable_checkpointing=True), the trainer will already save checkpoints to the logging directory. Furthermore, you can trigger a manual save yourself using trainer.save_checkpoint(...).
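For example (a minimal sketch; the checkpoint path is arbitrary):

trainer = pl.Trainer(strategy='fsdp', callbacks=callbacks, enable_checkpointing=True)
trainer.fit(model)

# Trigger an additional manual save; Lightning collects the sharded FSDP state for you.
trainer.save_checkpoint('model_checkpoints/manual_save.ckpt')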

For FSDP to work properly (and it is still experimental and incomplete!), you need to install Lightning from source.

pip install https://github.com/Lightning-AI/lightning/archive/refs/heads/master.zip -U

Hi, the issue still persists, despite installing from source and using automatic checkpointing.