I am trying to save the state_dict of a fine-tuned T5 model from Hugging Face. However, I later found that I could not properly load from the checkpoint, and when I inspected the state_dict stored in the checkpoint, it was actually empty.
import pytorch_lightning as pl
import torch
from transformers import T5ForConditionalGeneration

class T5FineTuner(pl.LightningModule):
    def __init__(self, hparams):
        super(T5FineTuner, self).__init__()
        self.save_hyperparameters(hparams)
        self.model = T5ForConditionalGeneration.from_pretrained(hparams.model_name_or_path)

model = T5FineTuner(args)

early_stop_callback = pl.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=3,
    verbose=False,
    mode='min')

checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath='model_checkpoints/',
    filename=args.model_name_or_path,  # your custom filename
    save_weights_only=True,
    monitor='val_loss',
    mode='min',
    save_top_k=1,
)

# LoggingCallback, logger, output_num_file and args are defined elsewhere in my script
callbacks = [LoggingCallback(logger=logger, task=args.task, file_path=output_num_file, args=args),
             early_stop_callback,
             checkpoint_callback]

trainer = pl.Trainer(strategy='fsdp', callbacks=callbacks, enable_checkpointing=True)
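After training, I inspect the checkpoint file roughly like this (the exact filename is illustrative; the real one comes from the ModelCheckpoint callback above):

# Sketch of how I check the saved checkpoint (path is illustrative)
ckpt = torch.load('model_checkpoints/t5-base.ckpt', map_location='cpu')
print(ckpt.keys())              # the 'state_dict' key is there...
print(len(ckpt['state_dict']))  # ...but it holds 0 entries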
Is this due to using FSDP? How do I properly save the model when using FSDP? I also tried saving it manually with torch.save(self.model.state_dict()), and that was empty as well. However, when I am not using distributed training, it saves fine.
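For context, the manual save attempt looks roughly like this (the hook I call it from here is just illustrative):

# Inside the LightningModule -- a sketch of my manual save attempt
def on_fit_end(self):
    # Under FSDP each rank holds only a shard of the parameters,
    # which I suspect is why the resulting dict comes back empty
    torch.save(self.model.state_dict(), 'manual_t5_weights.pt')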