Hi! I have followed the guide here to train a BART model from Hugging Face with FSDP. Here is what the module looks like:
```python
import logging
from typing import Any, Dict

import lightning as L
import torch
from omegaconf import DictConfig
from transformers import BartConfig, BartForConditionalGeneration

logger = logging.getLogger(__name__)


class BartModule(L.LightningModule):
    def __init__(self, config: "DictConfig"):
        super().__init__()
        self.save_hyperparameters(config)
        self.config = config

    def configure_model(self) -> None:
        model_config = BartConfig.from_dict(self.config.model)
        self.model = BartForConditionalGeneration(model_config)
        try:
            torch.compile(self.model)
            logger.info("Model compiled successfully!")
        except Exception as e:
            logger.error(f"Error compiling model: {e}\nSkipping compilation...")

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        ...
```
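For completeness, the Trainer setup is roughly the following (a sketch only: the device count, precision, and paths are placeholders, not my exact config, and `config`/`datamodule` are defined elsewhere):

```python
from lightning.pytorch.strategies import FSDPStrategy

# Rough sketch of the training setup; actual arguments and the datamodule are omitted.
trainer = L.Trainer(
    accelerator="gpu",
    devices=4,                       # placeholder device count
    strategy=FSDPStrategy(),         # FSDP sharding
    precision="bf16-mixed",
    default_root_dir="checkpoints",  # placeholder checkpoint directory
)
trainer.fit(BartModule(config), datamodule=datamodule)
```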
The model trains and produces checkpoints in the specified directory. However, when I try to load one using `BartModule.load_from_checkpoint(path_to_checkpoint)`, I get the following error:
Runtime error details:

```
RuntimeError: Error(s) in loading state_dict for BartModule:
Unexpected key(s) in state_dict: “model.final_logits_bias”, “model.model.shared.weight”, “model.model.encoder.embed_tokens.weight”, “model.model.encoder.embed_positions.weight”, “model.model.encoder.layers.0.self_attn.k_proj.weight”, “model.model.encoder.layers.0.self_attn.k_proj.bias”, “model.model.encoder.layers.0.self_attn.v_proj.weight”, “model.model.encoder.layers.0.self_attn.v_proj.bias”, “model.model.encoder.layers.0.self_attn.q_proj.weight”, “model.model.encoder.layers.0.self_attn.q_proj.bias”, “model.model.encoder.layers.0.self_attn.out_proj.weight”, “model.model.encoder.layers.0.self_attn.out_proj.bias”, “model.model.encoder.layers.0.self_attn_layer_norm.weight”, “model.model.encoder.layers.0.self_attn_layer_norm.bias”, “model.model.encoder.layers.0.fc1.weight”, “model.model.encoder.layers.0.fc1.bias”, “model.model.encoder.layers.0.fc2.weight”, “model.model.encoder.layers.0.fc2.bias”, “model.model.encoder.layers.0.final_layer_norm.weight”, “model.model.encoder.layers.0.final_layer_norm.bias”, “model.model.encoder.layers.1.self_attn.k_proj.weight”, “model.model.encoder.layers.1.self_attn.k_proj.bias”, “model.model.encoder.layers.1.self_attn.v_proj.weight”, “model.model.encoder.layers.1.self_attn.v_proj.bias”, “model.model.encoder.layers.1.self_attn.q_proj.weight”, “model.model.encoder.layers.1.self_attn.q_proj.bias”, “model.model.encoder.layers.1.self_attn.out_proj.weight”, “model.model.encoder.layers.1.self_attn.out_proj.bias”, “model.model.encoder.layers.1.self_attn_layer_norm.weight”, “model.model.encoder.layers.1.self_attn_layer_norm.bias”, “model.model.encoder.layers.1.fc1.weight”, “model.model.encoder.layers.1.fc1.bias”, “model.model.encoder.layers.1.fc2.weight”, “model.model.encoder.layers.1.fc2.bias”, “model.model.encoder.layers.1.final_layer_norm.weight”, “model.model.encoder.layers.1.final_layer_norm.bias”, “model.model.encoder.layernorm_embedding.weight”, “model.model.encoder.layernorm_embedding.bias”, “model.model.decoder.embed_tokens.weight”, “model.model.decoder.embed_positions.weight”, “model.model.decoder.layers.0.self_attn.k_proj.weight”, “model.model.decoder.layers.0.self_attn.k_proj.bias”, “model.model.decoder.layers.0.self_attn.v_proj.weight”, “model.model.decoder.layers.0.self_attn.v_proj.bias”, “model.model.decoder.layers.0.self_attn.q_proj.weight”, “model.model.decoder.layers.0.self_attn.q_proj.bias”, “model.model.decoder.layers.0.self_attn.out_proj.weight”, “model.model.decoder.layers.0.self_attn.out_proj.bias”, “model.model.decoder.layers.0.self_attn_layer_norm.weight”, “model.model.decoder.layers.0.self_attn_layer_norm.bias”, “model.model.decoder.layers.0.encoder_attn.k_proj.weight”, “model.model.decoder.layers.0.encoder_attn.k_proj.bias”, “model.model.decoder.layers.0.encoder_attn.v_proj.weight”, “model.model.decoder.layers.0.encoder_attn.v_proj.bias”, “model.model.decoder.layers.0.encoder_attn.q_proj.weight”, “model.model.decoder.layers.0.encoder_attn.q_proj.bias”, “model.model.decoder.layers.0.encoder_attn.out_proj.weight”, “model.model.decoder.layers.0.encoder_attn.out_proj.bias”, “model.model.decoder.layers.0.encoder_attn_layer_norm.weight”, “model.model.decoder.layers.0.encoder_attn_layer_norm.bias”, “model.model.decoder.layers.0.fc1.weight”, “model.model.decoder.layers.0.fc1.bias”, “model.model.decoder.layers.0.fc2.weight”, “model.model.decoder.layers.0.fc2.bias”, “model.model.decoder.layers.0.final_layer_norm.weight”, “model.model.decoder.layers.0.final_layer_norm.bias”, “model.model.decoder.layers.1.self_attn.k_proj.weight”, 
“model.model.decoder.layers.1.self_attn.k_proj.bias”, “model.model.decoder.layers.1.self_attn.v_proj.weight”, “model.model.decoder.layers.1.self_attn.v_proj.bias”, “model.model.decoder.layers.1.self_attn.q_proj.weight”, “model.model.decoder.layers.1.self_attn.q_proj.bias”, “model.model.decoder.layers.1.self_attn.out_proj.weight”, “model.model.decoder.layers.1.self_attn.out_proj.bias”, “model.model.decoder.layers.1.self_attn_layer_norm.weight”, “model.model.decoder.layers.1.self_attn_layer_norm.bias”, “model.model.decoder.layers.1.encoder_attn.k_proj.weight”, “model.model.decoder.layers.1.encoder_attn.k_proj.bias”, “model.model.decoder.layers.1.encoder_attn.v_proj.weight”, “model.model.decoder.layers.1.encoder_attn.v_proj.bias”, “model.model.decoder.layers.1.encoder_attn.q_proj.weight”, “model.model.decoder.layers.1.encoder_attn.q_proj.bias”, “model.model.decoder.layers.1.encoder_attn.out_proj.weight”, “model.model.decoder.layers.1.encoder_attn.out_proj.bias”, “model.model.decoder.layers.1.encoder_attn_layer_norm.weight”, “model.model.decoder.layers.1.encoder_attn_layer_norm.bias”, “model.model.decoder.layers.1.fc1.weight”, “model.model.decoder.layers.1.fc1.bias”, “model.model.decoder.layers.1.fc2.weight”, “model.model.decoder.layers.1.fc2.bias”, “model.model.decoder.layers.1.final_layer_norm.weight”, “model.model.decoder.layers.1.final_layer_norm.bias”, “model.model.decoder.layernorm_embedding.weight”, “model.model.decoder.layernorm_embedding.bias”, “model.lm_head.weight”.
```

I figured out that the issue is the extra "model." prefix at the start of every key (it comes from the `self.model` attribute on the LightningModule), which stops the HF model from loading the state_dict correctly. A quick look at the checkpoint confirms it, as shown below.
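To double-check, I inspected a checkpoint directly (a minimal sketch; the path is a placeholder, and this assumes a consolidated single-file .ckpt rather than an FSDP sharded directory):

```python
import torch

# Placeholder path to one of the Lightning checkpoints produced during training.
# weights_only=False because the checkpoint also stores the hyperparameters object.
ckpt = torch.load("checkpoints/epoch=0-step=1000.ckpt", map_location="cpu", weights_only=False)

# Lightning saves the LightningModule weights under "state_dict";
# every key is prefixed with the attribute name, i.e. "model.".
print(list(ckpt["state_dict"].keys())[:3])
# ['model.final_logits_bias', 'model.model.shared.weight', 'model.model.encoder.embed_tokens.weight']
```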
I then updated the `on_load_checkpoint` method to strip that prefix from each state_dict key. This works:
```python
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    model_config = BartConfig.from_dict(checkpoint["hyper_parameters"].model)
    self.model = BartForConditionalGeneration(model_config)

    # Strip only the first "model." (the LightningModule attribute name) from each key.
    new_state_dict = {}
    for key, value in checkpoint["state_dict"].items():
        if key.startswith("model."):
            new_key = ".".join(key.split(".")[1:])
            new_state_dict[new_key] = value
    self.model.load_state_dict(new_state_dict)

    try:
        torch.compile(self.model)
    except Exception as e:
        logger.error(f"Error compiling model: {e}")
```
However, this doesn't seem very efficient to me, mainly because we end up holding two copies of the state_dict in memory. Is there a different way I should approach this? I also tried renaming the keys in place (popping each key from `checkpoint["state_dict"]` and re-inserting it without the prefix), but Lightning didn't seem to like that; see the sketch below.
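For reference, the in-place attempt was along these lines (a rough reconstruction, not the exact code I ran):

```python
def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    model_config = BartConfig.from_dict(checkpoint["hyper_parameters"].model)
    self.model = BartForConditionalGeneration(model_config)

    # Rename keys in place instead of building a second dict:
    # pop each "model.<rest>" key and re-insert it as "<rest>".
    state_dict = checkpoint["state_dict"]
    for key in list(state_dict.keys()):
        if key.startswith("model."):
            state_dict[key[len("model."):]] = state_dict.pop(key)

    self.model.load_state_dict(state_dict)
```

My guess is that the renamed keys then no longer match the LightningModule's own `model.*` parameter names when Lightning runs its own load afterwards, but I haven't verified that. I also noticed `torch.nn.modules.utils.consume_prefix_in_state_dict_if_present`, which seems to do the same in-place renaming and would presumably hit the same problem.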