Finetuning a model from the CLI (overwriting optimizer states, etc)

Premise:

I’m using the Lightning CLI with a YAML configuration file to govern my training regime:

python3 main.py --config config.yaml

I pretrained a model for 100 epochs with the SGD optimizer and a Step Learning Rate Scheduler with an initial learning rate of 1e-03, a step size of 2, and a gamma of 0.9.

My best weights were saved as PyTorch Lightning .ckpt files.

Excerpt from pretraining config file:

data:
  class_path: OldDataModule
# ckpt_path: 
optimizer:
  class_path: torch.optim.SGD                   # Optimizer. See https://pytorch.org/docs/stable/optim.html for list of available optimizers
  init_args:
    lr: 1e-03
lr_scheduler:
  class_path: torch.optim.lr_scheduler.StepLR   # Learning rate scheduler. See https://pytorch.org/docs/stable/optim.html for list of available schedulers
  init_args:
    step_size: 2                                # Step size in number of epochs for learning rate decay
    gamma: 0.90                                 # Learning rate decay

Problem:

Now, I’d like to finetune the model. I change config.yaml to:

  • use a new LightningDataModule
  • load a checkpoint with the ckpt_path parameter
  • use a new initial learning rate for the optimizer and learning rate scheduler

However, training resumes from the epoch at which the checkpoint was saved, using the learning rate stored in the checkpoint rather than the new one.

New excerpt of my finetuning config file:

data:
  class_path: NewDataModule
ckpt_path: ./logs/pretrained.ckpt
optimizer:
  class_path: torch.optim.SGD                   # Optimizer. See https://pytorch.org/docs/stable/optim.html for list of available optimizers
  init_args:
    lr: 5e-08
lr_scheduler:
  class_path: torch.optim.lr_scheduler.ExponentialLR
  init_args:
    gamma: 0.10                                 # Learning rate decay

Question:

What is the best way to load a checkpoint (essentially warm-starting the model) and reset the optimizer state using YAML config files and the Lightning CLI?

Attempts:

I reset the learning rate manually by creating a custom callback that uses the on_load_checkpoint hook:

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback

class OverwriteLR(Callback):
    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint) -> None:
        # The hook receives the checkpoint dict directly (there is no trainer.checkpoint attribute)
        checkpoint['optimizer_states'][0]['param_groups'][0]['lr'] = 1e-08

I’d prefer not to do this because the value is hardcoded.

I’d prefer to tune or change all of my hyperparameters from the YAML file; however, the optimizer, scheduler, and learning rate in my YAML configuration are not exposed to this callback, and I couldn’t quite figure out how to expose them.
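One way to avoid the hardcoding (a sketch only; the class and module names here are my own, and the import fallback exists just so the snippet runs standalone without Lightning installed) would be to give the callback the learning rate as an init arg, so the value lives in the YAML like any other hyperparameter:

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:  # minimal stand-in so the sketch runs without Lightning
    class Callback:
        pass


class ResetLR(Callback):
    """Overwrite the learning rate restored from a checkpoint with a configured value."""

    def __init__(self, lr: float):
        self.lr = lr

    def on_load_checkpoint(self, trainer, pl_module, checkpoint) -> None:
        # The hook receives the checkpoint dict directly; mutate it here so the
        # configured lr is used instead of the one stored in the checkpoint.
        for opt_state in checkpoint.get("optimizer_states", []):
            for group in opt_state["param_groups"]:
                group["lr"] = self.lr


# Registered from the finetuning YAML (module path is hypothetical):
#
# trainer:
#   callbacks:
#     - class_path: callbacks.ResetLR
#       init_args:
#         lr: 5e-08
```

This at least moves the value out of the code, though it still duplicates the lr already given under optimizer.init_args.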

I’d be grateful for any ideas on this matter. Thanks!

Version: Lightning 2.0.2

I am facing the same problem. Did you find a proper way to do this instead of using the callback?

Unfortunately not. I had to manually overwrite the learning rate and scheduler state by hardcoding the value in the callback.
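The cleanest warm-start alternative I considered (a sketch only, not verified against the CLI; the function name is hypothetical) is to leave ckpt_path unset and load just the model weights yourself, since a full .ckpt also carries the optimizer, scheduler, and loop state that causes training to resume:

```python
from typing import Any, Dict


def extract_weights(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the model weights from a Lightning checkpoint dict.

    A full .ckpt also contains 'optimizer_states', 'lr_schedulers', 'epoch',
    and loop state; dropping everything but 'state_dict' is what a warm
    start amounts to.
    """
    return checkpoint["state_dict"]


# Inside a LightningModule this would be used roughly as (untested sketch;
# 'pretrained_path' is a hypothetical init arg, settable from the YAML via
# model.init_args.pretrained_path while leaving ckpt_path unset):
#
#   if pretrained_path:
#       ckpt = torch.load(pretrained_path, map_location="cpu")
#       self.load_state_dict(extract_weights(ckpt))
```

With no ckpt_path given, the optimizer and lr_scheduler sections of the finetuning YAML then apply fresh.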

Version: Lightning 2.1.0
