Loading saved checkpoint model.model

Hi,
While trying to load a checkpoint,

model = EnhanceModel.load_from_checkpoint('.../saved_weights/epoch=45-step=154054.ckpt', model = some_model, hparams=hyperparams)

I encounter the following error:

RuntimeError: Error(s) in loading state_dict for EnhanceModel:
	Unexpected key(s) in state_dict: "model.encoder.conv0.0.weight", "model.encoder.conv0.0.bias", "model.encoder.conv0.1.weight",.....

My module's __init__ stores the passed-in network as self.model:

class EnhanceModel(pl.LightningModule):
    def __init__(self, hparams, model):
        super().__init__()
        self.hyperparams = hparams
        self.model = model
        self.disp_loss_func = sequence_loss

What would be the correct way to checkpoint in the future, and do you have any suggestions on how to load this model?

Hey

Does your “self.model” actually have the model.encoder etc. layers defined? I can only guess, but from the error message it looks like you might have changed your code/model definition and it no longer matches what is in the checkpoint.

If you have removed some layers but still want to load the weights for the layers whose names match, you can disable strict loading:

EnhanceModel.load_from_checkpoint(..., strict=False)
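
Applied to the call in your post, that would look roughly like this (same arguments as before, just with strict=False added):

model = EnhanceModel.load_from_checkpoint(
    '.../saved_weights/epoch=45-step=154054.ckpt',
    model=some_model,
    hparams=hyperparams,
    strict=False,
)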

But first make sure that this is what you want 🙂

Hi,

The model is the same. EnhanceModel takes an inner model whose state_dict has the same keys as in the error message, just without the ‘model.’ prefix at the beginning.
So instead of:

"model.encoder.conv0.0.weight", "model.encoder.conv0.0.bias", "model.encoder.conv0.1.weight",.....

the inner model's keys are:

"encoder.conv0.0.weight", "encoder.conv0.0.bias", "encoder.conv0.1.weight",.....

If I simply do:

import torch

checkpoint = torch.load('.../saved_weights/epoch=45-step=154054.ckpt')
state_dict = checkpoint['state_dict']
new_state_dict = {}
for key in state_dict.keys():
    new_key = key[len('model.'):]  # strip the 'model.' prefix
    new_state_dict[new_key] = state_dict[key]

the resulting new_state_dict has the correct keys.
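
For completeness, a minimal sketch of how the remapped dict can then be used, assuming some_model is the same nn.Module instance passed to EnhanceModel earlier:

# Load the remapped weights into the bare network,
# then wrap it in the LightningModule as before.
some_model.load_state_dict(new_state_dict)
model = EnhanceModel(hyperparams, some_model)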

I’m just wondering what the better way of checkpointing would be in the future.

The code is:

class EnhanceModel(pl.LightningModule):
    def __init__(self, hparams, model):
        super().__init__()
        self.hyperparams = hparams
        self.model = model
        self.disp_loss_func = sequence_loss
        ....

class net1(nn.Module):

    def __init__(self):
        super(net1, self).__init__()
        self.encoder = BasicEncoder()
        self.decoder = decoder()

    def forward(self, x):  # 'self' was missing here
        x = self.encoder(x)
        x = self.decoder(x)
        return x

model = EnhanceModel(hyperparams, net1())
trainer = pl.Trainer(gpus=1, max_epochs=hyperparams.epochs, logger=wandb_logger)
trainer.fit(model)