Finetuning a model from the CLI (overwriting optimizer states, etc)


I’m using the Lightning CLI with a YAML configuration file to control my training, launched with:

python3 --config config.yaml

I pretrained a model for 100 epochs with the SGD optimizer and a Step Learning Rate Scheduler with an initial learning rate of 1e-03, a step size of 2, and a gamma of 0.9.

My best weights were saved as PyTorch Lightning .ckpt files.

Excerpt from pretraining config file:

  class_path: OldDataModule
# ckpt_path: 
  class_path: torch.optim.SGD                   # Optimizer
    lr: 1e-03
  class_path: torch.optim.lr_scheduler.StepLR   # Learning rate scheduler
    step_size: 2                                # Step size in number of epochs for learning rate decay
    gamma: 0.90                                 # Learning rate decay


Now, I’d like to finetune the model. I change config.yaml to:

  • use a new LightningDataModule
  • load a checkpoint with the ckpt_path parameter
  • use a new initial learning rate for the optimizer and a new learning rate scheduler

However, training resumes from the epoch at which the checkpoint was saved and uses the learning rate stored in the checkpoint instead of the new one from the config file.

New excerpt of my finetuning config file:

  class_path: NewDataModule
ckpt_path: ./logs/pretrained.ckpt
  class_path: torch.optim.SGD                   # Optimizer
    lr: 5e-08
  class_path: torch.optim.lr_scheduler.ExponentialLR
    gamma: 0.10                                 # Learning rate decay


What is the best way for me to load in a checkpoint (essentially warmstarting the model) and reset the optimizer using yaml config files and the Lightning CLI?
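One possible warm-start approach, sketched under assumptions: treat the .ckpt file as a plain dict and keep only the model weights, so that nothing about the old epoch counter, optimizer, or scheduler is carried over. The key names in TRAINING_STATE_KEYS are an assumption about the usual Lightning checkpoint layout, and split_checkpoint is a hypothetical helper, not a Lightning API. The demo uses a stand-in dict rather than a real checkpoint.

```python
from typing import Any, Dict, Tuple

# Assumed names of the checkpoint entries that carry resumable training state
# (as opposed to model weights). Verify against your own .ckpt file.
TRAINING_STATE_KEYS = ("optimizer_states", "lr_schedulers", "loops", "epoch", "global_step")


def split_checkpoint(checkpoint: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Return (weights, training_state) without modifying the input dict."""
    weights = checkpoint["state_dict"]
    training_state = {k: checkpoint[k] for k in TRAINING_STATE_KEYS if k in checkpoint}
    return weights, training_state


# Stand-in for the dict that loading a .ckpt file would produce.
fake_checkpoint = {
    "epoch": 99,
    "global_step": 12345,
    "state_dict": {"layer.weight": [0.1, 0.2]},
    "optimizer_states": [{"param_groups": [{"lr": 1e-3}]}],
    "lr_schedulers": [{"last_epoch": 99}],
}

weights, training_state = split_checkpoint(fake_checkpoint)
print(sorted(training_state))  # the entries a warm start would ignore
print(list(weights))           # the entries a warm start would keep
```

With a real file, the dict would presumably come from torch.load(path) and the weights would be passed to the module's load_state_dict, with ckpt_path left unset so that the optimizer and scheduler are built fresh from the YAML values.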


I reset the learning rate manually by creating a custom callback with the on_load_checkpoint hook using:

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback

class OverwriteLR(Callback):
    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint) -> None:
        # Overwrite the learning rate stored in the loaded checkpoint dict
        checkpoint['optimizer_states'][0]['param_groups'][0]['lr'] = 1e-08

I’d prefer not to do this because the value is hardcoded.

I’d prefer to tune or change all of my hyperparameters from the yaml file. However, the optimizer, scheduler, and learning rate in my yaml configuration file are not exposed to this callback, and I couldn’t figure out how to expose them.
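One way to avoid the hardcoded value, sketched under assumptions: move the dict edit into a function that takes the learning rate as an argument, then give the callback an lr constructor argument that the CLI can fill from the callback's init_args under trainer.callbacks in the YAML file. The name override_checkpoint_lr is hypothetical, not a Lightning API, and the function assumes the same checkpoint layout as the callback above (optimizer_states, then param_groups, then lr). It does not touch any scheduler state.

```python
from typing import Any, Dict


def override_checkpoint_lr(checkpoint: Dict[str, Any], lr: float) -> Dict[str, Any]:
    """Set `lr` on every param group of every stored optimizer state, in place."""
    for optimizer_state in checkpoint.get("optimizer_states", []):
        for group in optimizer_state.get("param_groups", []):
            group["lr"] = lr
    return checkpoint


# Stand-in dict with the layout assumed above; a real one is passed to on_load_checkpoint.
fake_checkpoint = {
    "optimizer_states": [{"param_groups": [{"lr": 1e-3, "momentum": 0.0}]}],
}

override_checkpoint_lr(fake_checkpoint, lr=5e-8)
print(fake_checkpoint["optimizer_states"][0]["param_groups"][0]["lr"])
```

Inside the callback, __init__(self, lr: float) would store the value and on_load_checkpoint would call override_checkpoint_lr(checkpoint, self.lr), so the number lives in the config file next to the other hyperparameters.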

I’d be grateful for any ideas on this matter. Thanks!

Version: Lightning 2.0.2

I am facing the same problem. Did you find a proper way to do this instead of using this callback?

Unfortunately not. I had to manually overwrite the learning rate and scheduler state by hardcoding the value in the callback.

Version: Lightning 2.1.0

