Change/reset ModelCheckpoint.best_model_score upon loading checkpoint

Hey all,

Context: I want to re-train a pre-trained model on a different data set. I have two ModelCheckpoint callbacks that save the best models according to certain metrics. I load the pre-trained model with trainer.fit(..., ckpt_path='my_checkpoint_file'). The callback states are stored in the checkpoint file, so ModelCheckpoint.best_model_score is restored from pre-training. As a result, during re-training the ModelCheckpoint callbacks only save a new model if the re-training metric beats the best pre-training score.

Intended behavior: I want the callbacks to consider only scores from the re-training data set when deciding which model is best.

My solution: a callback that manually resets best_model_score if a checkpoint was provided, i.e.:

from pytorch_lightning.callbacks import Callback
import torch


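class MyCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        # trainer.callbacks[4] is one of my ModelCheckpoint callbacks
        if trainer.callbacks[4].best_model_score is not None:
            # overwrite the score restored from the pre-training checkpoint
            trainer.callbacks[4].best_model_score = torch.tensor(1.)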

Expected behavior:
best_model_score in checkpoint file = 0.5
best_model_score after MyCallback = 1.0
model_score during validation = 0.7 → saves a new best model

Problem: although best_model_score is reset (I confirmed this during the validation step), the checkpoint callback still uses the score from before MyCallback ran.

Observed behavior:
best_model_score in checkpoint file = 0.5
best_model_score after MyCallback = 1.0
model_score during validation = 0.7 → does not save a new best model (the best model still has a score of 0.5)

My system:

python==3.8.7
torch==1.10.2
pytorch-lightning==1.5.9

I'd be glad for any hints.

Best,
dsethz

It seems you do not have to set best_model_score manually, but rather the values in ModelCheckpoint.best_k_models, i.e.:

from pytorch_lightning.callbacks import Callback
import torch


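class MyCallback(Callback):
    def on_fit_start(self, trainer, pl_module):
        # trainer.callbacks[4] is one of the ModelCheckpoint callbacks
        if trainer.callbacks[4].best_model_score is not None:
            # the save decision compares against best_k_models, so overwrite
            # the score stored there for the restored checkpoint path
            trainer.callbacks[4].best_k_models['path_to_model'] = torch.tensor(1.)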

where 'path_to_model' is the existing key in best_k_models, usually the path to the checkpoint file.
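Why this works: as far as I can tell from the 1.5.x source, ModelCheckpoint decides whether to save by comparing the new score with the k-th best entry in best_k_models (a dict mapping checkpoint path to score). best_model_score is only updated afterwards, so overwriting it has no effect on the decision. The sketch below is a simplified pure-Python model of that bookkeeping, not Lightning's actual code. The class TopKTracker and its methods are hypothetical, and it assumes mode='min' and save_top_k=1, which matches the 0.5 / 0.7 / 1.0 numbers in the question.

```python
from typing import Dict, Optional


class TopKTracker:
    """Hypothetical, simplified model of ModelCheckpoint's top-k bookkeeping.

    Assumes mode="min" (lower score is better).
    """

    def __init__(self, save_top_k: int = 1) -> None:
        self.save_top_k = save_top_k
        self.best_k_models: Dict[str, float] = {}  # checkpoint path -> score
        self.best_model_score: Optional[float] = None  # reporting only

    def kth_best_path(self) -> str:
        # In "min" mode the k-th best (worst kept) model has the highest score.
        return max(self.best_k_models, key=lambda p: self.best_k_models[p])

    def should_save(self, current: float) -> bool:
        # Fewer than k models tracked: always save.
        if len(self.best_k_models) < self.save_top_k:
            return True
        # Otherwise compare against the k-th best entry of best_k_models.
        # best_model_score is never consulted here.
        return current < self.best_k_models[self.kth_best_path()]

    def save(self, path: str, current: float) -> None:
        # Evict the k-th best entry when the top-k list is full.
        if len(self.best_k_models) >= self.save_top_k:
            self.best_k_models.pop(self.kth_best_path())
        self.best_k_models[path] = current
        self.best_model_score = min(self.best_k_models.values())


# State restored from the pre-training checkpoint (score 0.5).
tracker = TopKTracker()
tracker.save("pretrained.ckpt", 0.5)

# Resetting only best_model_score does not change the decision.
tracker.best_model_score = 1.0
print(tracker.should_save(0.7))  # False

# Overwriting the entry in best_k_models does.
tracker.best_k_models["pretrained.ckpt"] = 1.0
print(tracker.should_save(0.7))  # True

# Clearing best_k_models altogether also lets 0.7 be saved.
tracker.best_k_models.clear()
print(tracker.should_save(0.7))  # True
```

Two things worth checking against your own version (both are my assumptions, not confirmed in this thread): when the top-k list is full and a better model arrives, ModelCheckpoint deletes the file of the evicted path, so if 'path_to_model' is the real pre-training checkpoint it may get removed; clearing best_k_models instead of faking a score should avoid that. Also, iterating over trainer.checkpoint_callbacks rather than indexing trainer.callbacks[4] would reset both ModelCheckpoint callbacks without depending on their position in the list.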