Failed to load checkpoint on external model

Hi! I have followed the guide here to train a BART model from HF with FSDP. Here is what the module looks like:

class BartModule(L.LightningModule):
    def __init__(self, config: "DictConfig"):
        super().__init__()
        self.save_hyperparameters(config)
        self.config = config

    def configure_model(self) -> None:
        model_config = BartConfig.from_dict(self.config.model)
        self.model = BartForConditionalGeneration(model_config)

        try:
            torch.compile(self.model)
            logger.info("Model compiled successfully!")
        except Exception as e:
            logger.error(f"Error compiling model: {e}\nSkipping compilation...")

    def forward(self, **inputs):
        return self.model(**inputs)

    def training_step(self, batch, batch_idx):
        ...

The model trains, and produces checkpoints in the specified directory. However, when I try to load one using BartModule.load_from_checkpoint(path_to_checkpoint), I get the following error:

Runtime error details

RuntimeError: Error(s) in loading state_dict for BartModule:
Unexpected key(s) in state_dict: “model.final_logits_bias”, “model.model.shared.weight”, “model.model.encoder.embed_tokens.weight”, “model.model.encoder.embed_positions.weight”, “model.model.encoder.layers.0.self_attn.k_proj.weight”, “model.model.encoder.layers.0.self_attn.k_proj.bias”, “model.model.encoder.layers.0.self_attn.v_proj.weight”, “model.model.encoder.layers.0.self_attn.v_proj.bias”, “model.model.encoder.layers.0.self_attn.q_proj.weight”, “model.model.encoder.layers.0.self_attn.q_proj.bias”, “model.model.encoder.layers.0.self_attn.out_proj.weight”, “model.model.encoder.layers.0.self_attn.out_proj.bias”, “model.model.encoder.layers.0.self_attn_layer_norm.weight”, “model.model.encoder.layers.0.self_attn_layer_norm.bias”, “model.model.encoder.layers.0.fc1.weight”, “model.model.encoder.layers.0.fc1.bias”, “model.model.encoder.layers.0.fc2.weight”, “model.model.encoder.layers.0.fc2.bias”, “model.model.encoder.layers.0.final_layer_norm.weight”, “model.model.encoder.layers.0.final_layer_norm.bias”, “model.model.encoder.layers.1.self_attn.k_proj.weight”, “model.model.encoder.layers.1.self_attn.k_proj.bias”, “model.model.encoder.layers.1.self_attn.v_proj.weight”, “model.model.encoder.layers.1.self_attn.v_proj.bias”, “model.model.encoder.layers.1.self_attn.q_proj.weight”, “model.model.encoder.layers.1.self_attn.q_proj.bias”, “model.model.encoder.layers.1.self_attn.out_proj.weight”, “model.model.encoder.layers.1.self_attn.out_proj.bias”, “model.model.encoder.layers.1.self_attn_layer_norm.weight”, “model.model.encoder.layers.1.self_attn_layer_norm.bias”, “model.model.encoder.layers.1.fc1.weight”, “model.model.encoder.layers.1.fc1.bias”, “model.model.encoder.layers.1.fc2.weight”, “model.model.encoder.layers.1.fc2.bias”, “model.model.encoder.layers.1.final_layer_norm.weight”, “model.model.encoder.layers.1.final_layer_norm.bias”, “model.model.encoder.layernorm_embedding.weight”, “model.model.encoder.layernorm_embedding.bias”, “model.model.decoder.embed_tokens.weight”, “model.model.decoder.embed_positions.weight”, “model.model.decoder.layers.0.self_attn.k_proj.weight”, “model.model.decoder.layers.0.self_attn.k_proj.bias”, “model.model.decoder.layers.0.self_attn.v_proj.weight”, “model.model.decoder.layers.0.self_attn.v_proj.bias”, “model.model.decoder.layers.0.self_attn.q_proj.weight”, “model.model.decoder.layers.0.self_attn.q_proj.bias”, “model.model.decoder.layers.0.self_attn.out_proj.weight”, “model.model.decoder.layers.0.self_attn.out_proj.bias”, “model.model.decoder.layers.0.self_attn_layer_norm.weight”, “model.model.decoder.layers.0.self_attn_layer_norm.bias”, “model.model.decoder.layers.0.encoder_attn.k_proj.weight”, “model.model.decoder.layers.0.encoder_attn.k_proj.bias”, “model.model.decoder.layers.0.encoder_attn.v_proj.weight”, “model.model.decoder.layers.0.encoder_attn.v_proj.bias”, “model.model.decoder.layers.0.encoder_attn.q_proj.weight”, “model.model.decoder.layers.0.encoder_attn.q_proj.bias”, “model.model.decoder.layers.0.encoder_attn.out_proj.weight”, “model.model.decoder.layers.0.encoder_attn.out_proj.bias”, “model.model.decoder.layers.0.encoder_attn_layer_norm.weight”, “model.model.decoder.layers.0.encoder_attn_layer_norm.bias”, “model.model.decoder.layers.0.fc1.weight”, “model.model.decoder.layers.0.fc1.bias”, “model.model.decoder.layers.0.fc2.weight”, “model.model.decoder.layers.0.fc2.bias”, “model.model.decoder.layers.0.final_layer_norm.weight”, “model.model.decoder.layers.0.final_layer_norm.bias”, “model.model.decoder.layers.1.self_attn.k_proj.weight”, 
“model.model.decoder.layers.1.self_attn.k_proj.bias”, “model.model.decoder.layers.1.self_attn.v_proj.weight”, “model.model.decoder.layers.1.self_attn.v_proj.bias”, “model.model.decoder.layers.1.self_attn.q_proj.weight”, “model.model.decoder.layers.1.self_attn.q_proj.bias”, “model.model.decoder.layers.1.self_attn.out_proj.weight”, “model.model.decoder.layers.1.self_attn.out_proj.bias”, “model.model.decoder.layers.1.self_attn_layer_norm.weight”, “model.model.decoder.layers.1.self_attn_layer_norm.bias”, “model.model.decoder.layers.1.encoder_attn.k_proj.weight”, “model.model.decoder.layers.1.encoder_attn.k_proj.bias”, “model.model.decoder.layers.1.encoder_attn.v_proj.weight”, “model.model.decoder.layers.1.encoder_attn.v_proj.bias”, “model.model.decoder.layers.1.encoder_attn.q_proj.weight”, “model.model.decoder.layers.1.encoder_attn.q_proj.bias”, “model.model.decoder.layers.1.encoder_attn.out_proj.weight”, “model.model.decoder.layers.1.encoder_attn.out_proj.bias”, “model.model.decoder.layers.1.encoder_attn_layer_norm.weight”, “model.model.decoder.layers.1.encoder_attn_layer_norm.bias”, “model.model.decoder.layers.1.fc1.weight”, “model.model.decoder.layers.1.fc1.bias”, “model.model.decoder.layers.1.fc2.weight”, “model.model.decoder.layers.1.fc2.bias”, “model.model.decoder.layers.1.final_layer_norm.weight”, “model.model.decoder.layers.1.final_layer_norm.bias”, “model.model.decoder.layernorm_embedding.weight”, “model.model.decoder.layernorm_embedding.bias”, “model.lm_head.weight”.

I figured out that the issue is the extra "model." prefix at the beginning of each key, which stops the HF model from loading the state_dict correctly. I then updated the on_load_checkpoint method to strip it from each state_dict key. This works:

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    model_config = BartConfig.from_dict(checkpoint["hyper_parameters"].model)
    self.model = BartForConditionalGeneration(model_config)

    # Strip the leading "model." prefix from each key
    new_state_dict = {}
    for key, value in checkpoint["state_dict"].items():
        if key.startswith("model."):
            new_key = ".".join(key.split(".")[1:])
            new_state_dict[new_key] = value

    self.model.load_state_dict(new_state_dict)

    try:
        torch.compile(self.model)
    except Exception as e:
        logger.error(f"Error compiling model: {e}")

However, this doesn’t really seem efficient to me, mainly because the state_dict ends up in memory twice. Is there a different way I should approach this? I tried popping each key from the old dict and re-adding it under the stripped name directly in the checkpoint, but Lightning didn’t seem to like that.
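For reference, the popping approach I tried looked roughly like this (a sketch from memory, so the exact code may have differed a bit):

state_dict = checkpoint["state_dict"]
for key in list(state_dict.keys()):
    if key.startswith("model."):
        # Rename the key in place instead of building a second dict
        state_dict[key[len("model."):]] = state_dict.pop(key)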

@cavoinea We’ve implemented the change in the linked PR, which should help with such issues in the future. The problem here is simply that your model isn’t fully defined at initialization: configure_model has to be called before the weights can be loaded, and the linked PR now does that automatically.
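If you’re on a release that doesn’t include that change yet, a minimal manual workaround is to do the same thing in the hook yourself. This is only a sketch (it assumes, as your own workaround already demonstrates, that load_from_checkpoint runs on_load_checkpoint before it applies the state_dict):

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    # Build the inner HF model so the "model.*" keys in the checkpoint
    # have matching parameters to load into. Lightning's own
    # load_state_dict call, which runs after this hook, then restores
    # the weights directly; no prefix stripping or second copy needed.
    if not hasattr(self, "model"):
        self.configure_model()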