Loading best checkpoint throws error

I used a callback to save my best model checkpoint, but when I try to load the model in a new session and restore the best weights, it keeps throwing errors. This is my code:

class DroneModel(pl.LightningModule):
    def __init__(self, model, optimizer, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

cbs = pl.callbacks.ModelCheckpoint(dirpath = f'./checkpoints_{arch}',
                                   filename = arch, 
                                   verbose = True, 
                                   monitor = 'valid_loss', 
                                   mode = 'min')

pl_model = DroneModel(model, optimizer, criterion)

trainer = pl.Trainer(callbacks=cbs, accelerator='gpu', max_epochs=25, auto_lr_find=True)
trainer.fit(pl_model, train_dl, val_dl)

model = DroneModel.load_state_dict(checkpoint_path, model=model, optimizer=optimizer, criterion=criterion)

And it gives me the error shown in the figure:

@santurini Can you show us the implementation of DroneModel.load_state_dict? It looks like you have overridden it. Make sure you are calling load_state_dict on the right object. The Lightning checkpoint contains a state dict whose keys start with the names of the top-level modules defined in the LightningModule.
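The prefixing behaviour Adrian describes can be seen with a minimal sketch using plain torch (the class and attribute names here are illustrative stand-ins, not the author's actual code): any submodule assigned to an attribute named model contributes a model. prefix to every key of the wrapper's state dict.

```python
import torch.nn as nn

# Toy stand-in for the bare network (e.g. what smp.Unet() would be).
class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4)

# Toy stand-in for the LightningModule wrapper: the attribute name
# 'model' becomes a prefix on every state-dict key.
class Wrapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = Inner()

print(list(Inner().state_dict().keys()))    # ['encoder.weight', 'encoder.bias']
print(list(Wrapper().state_dict().keys()))  # ['model.encoder.weight', 'model.encoder.bias']
```

So a checkpoint saved from the wrapper cannot be loaded directly into the bare network: every key carries the extra model. prefix.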

Hello Adrian, I solved the problem and will post the solution here in case it helps others.
I was trying to load the model like this:

model = smp.Unet()

This code was giving me the error in the figure. The problem was that the keys in the state dict saved by the checkpoint callback carried a model. prefix, so they did not match the keys expected by the bare network:

model.encoder._conv_stem.weight != encoder._conv_stem.weight

So my simple solution was this one:

model = smp.create_model(arch,
                         encoder_name = enc_name,
                         encoder_weights = "imagenet",
                         in_channels = 3,
                         classes = classes).to(device)

from collections import OrderedDict

state_dict = torch.load(cbs.best_model_path)['state_dict']
# Strip the leading 'model.' (6 characters) that Lightning added to every key
pl_state_dict = OrderedDict([(key[6:], state_dict[key]) for key in state_dict.keys()])
model.load_state_dict(pl_state_dict)


The testing loop was implemented by hand, which is why I called .to(device) and model.eval() explicitly.
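As a side note, the hard-coded key[6:] slice works only because the prefix happens to be exactly the six characters of "model.". A slightly more defensive variant strips the prefix by name instead (pure-Python sketch with hypothetical key names):

```python
from collections import OrderedDict

# Hypothetical keys mimicking what the Lightning checkpoint stores.
state_dict = {
    "model.encoder._conv_stem.weight": "w1",
    "model.decoder.block.0.weight": "w2",
}

prefix = "model."
# Remove the prefix only where it is actually present.
stripped = OrderedDict(
    (key[len(prefix):] if key.startswith(prefix) else key, value)
    for key, value in state_dict.items()
)
print(list(stripped.keys()))  # ['encoder._conv_stem.weight', 'decoder.block.0.weight']
```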
