I used callbacks to save my best model checkpoint; however, when I try to load the model in a new session and restore the best weights, it keeps throwing errors. This is my code:
class DroneModel(pl.LightningModule):
    def __init__(self, model, optimizer, criterion):
        super().__init__()
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer

cbs = pl.callbacks.ModelCheckpoint(dirpath=f'./checkpoints_{arch}',
                                   filename=arch,
                                   verbose=True,
                                   monitor='valid_loss',
                                   mode='min')

pl_model = DroneModel(model, optimizer, criterion)
trainer = pl.Trainer(callbacks=cbs, accelerator='gpu', max_epochs=25, auto_lr_find=True)
trainer.fit(pl_model, train_dl, val_dl)

model = DroneModel.load_state_dict(checkpoint_path, model=model, optimizer=optimizer, criterion=criterion)
And it gives me this error:
@santurini Can you show us the implementation of DroneModel.load_state_dict? It looks like you have overridden it. Make sure you are calling the load_state_dict on the right objects. The Lightning checkpoint will contain the state dict with module names starting with the top modules in the definition of the LightningModule.
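The key-prefixing behavior Adrian describes is plain `nn.Module` semantics, not something Lightning-specific: any submodule stored as an attribute of a wrapper module has its attribute name prepended to every key in the wrapper's state dict. A minimal sketch (the `Wrapper` class here is illustrative, standing in for the `DroneModel` LightningModule):

```python
import torch.nn as nn

class Wrapper(nn.Module):
    """Stands in for a LightningModule that stores the network as self.model."""
    def __init__(self, model):
        super().__init__()
        self.model = model  # attribute name "model" becomes a key prefix

inner = nn.Linear(4, 2)
wrapped = Wrapper(inner)

print(list(inner.state_dict()))    # ['weight', 'bias']
print(list(wrapped.state_dict()))  # ['model.weight', 'model.bias']
```

This is why the checkpoint saved while training `DroneModel` cannot be loaded directly into a bare `smp.Unet()`: every key carries the extra `model.` prefix.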
Hello Adrian, I solved the problem; I'll post the solution here in case it helps others.
I was trying to load the model like this:
model = smp.Unet()
model.load_state_dict(torch.load(PATH_TO_BEST_MODEL)['state_dict'])
This code was giving me the error in the figure. The problem was that, in the state dict saved by the checkpoint callback, the keys were named differently from what the bare model expects:
model.encoder._conv_stem.weight != encoder._conv_stem.weight
So my simple solution was this:
from collections import OrderedDict

model = smp.create_model(arch,
                         encoder_name=enc_name,
                         encoder_weights="imagenet",
                         in_channels=3,
                         classes=classes).to(device)

state_dict = torch.load(cbs.best_model_path)['state_dict']
# strip the leading "model." (6 characters) from every key
pl_state_dict = OrderedDict([(key[6:], state_dict[key]) for key in state_dict.keys()])
model.load_state_dict(pl_state_dict)
model.eval()
The testing loop was implemented by hand, which is why I called .to(device) and model.eval() explicitly.
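For anyone following along, a minimal sketch of what such a hand-written test loop can look like (the `evaluate` function and the assumption that the dataloader yields `(image, target)` batches are illustrative, not the poster's actual code):

```python
import torch

@torch.no_grad()  # disable gradient tracking during inference
def evaluate(model, dataloader, device):
    model.eval()  # switch off dropout, use running batch-norm statistics
    preds = []
    for images, _ in dataloader:
        images = images.to(device)   # move the batch to the model's device
        logits = model(images)
        # per-pixel class predictions for a segmentation model
        preds.append(logits.argmax(dim=1).cpu())
    return torch.cat(preds)
```

With `@torch.no_grad()` and `model.eval()` set, the loop mirrors what `trainer.validate` does internally, minus the logging.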