Premise:
I’m using the Lightning CLI with a yaml configuration file to govern my training regime with
python3 main.py --config config.yaml
I pretrained a model for 100 epochs with the SGD optimizer and a step learning rate scheduler (StepLR) with an initial learning rate of 1e-03, a step size of 2, and a gamma of 0.9. My best weights were saved as PyTorch Lightning .ckpt files.
Excerpt from pretraining config file:
data:
  class_path: OldDataModule
# ckpt_path:
optimizer:
  class_path: torch.optim.SGD  # Optimizer. See https://pytorch.org/docs/stable/optim.html for list of available optimizers
  init_args:
    lr: 1e-03
lr_scheduler:
  class_path: torch.optim.lr_scheduler.StepLR  # Learning rate scheduler. See https://pytorch.org/docs/stable/optim.html for list of available schedulers
  init_args:
    step_size: 2  # Step size in number of epochs for learning rate decay
    gamma: 0.90  # Learning rate decay
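If my understanding is right, the CLI's automatic optimizer support turns the optimizer and lr_scheduler entries above into roughly the following configure_optimizers (just a sketch of my mental model, with MyModel standing in for my actual model class; not code I run directly):

import torch
from lightning.pytorch import LightningModule

class MyModel(LightningModule):  # stand-in for my actual LightningModule
    def configure_optimizers(self):
        # Rough equivalent of the optimizer / lr_scheduler entries in the yaml excerpt above
        optimizer = torch.optim.SGD(self.parameters(), lr=1e-03)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.90)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}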
Problem:
Now, I’d like to finetune the model. I change config.yaml to:
- use a new LightningDataModule
- load a checkpoint with the ckpt_path parameter
- use a new initial learning rate for the optimizer and learning rate scheduler
However, training resumes from the epoch at which the checkpoint was saved, including using the learning rate stored within the checkpoint rather than the new one.
New excerpt of my finetuning config file:
data:
  class_path: NewDataModule
ckpt_path: ./logs/pretrained.ckpt
optimizer:
  class_path: torch.optim.SGD  # Optimizer. See https://pytorch.org/docs/stable/optim.html for list of available optimizers
  init_args:
    lr: 5e-08
lr_scheduler:
  class_path: torch.optim.lr_scheduler.ExponentialLR
  init_args:
    gamma: 0.10  # Learning rate decay
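My understanding (possibly wrong) is that ckpt_path is simply forwarded to trainer.fit(), so the entire training state is restored rather than just the weights:

from lightning.pytorch import Trainer

# model and datamodule stand for the objects the CLI instantiates from the config;
# with ckpt_path set, fit() restores the epoch counter, optimizer state, and
# LR scheduler state along with the weights.
trainer = Trainer()
trainer.fit(model, datamodule=datamodule, ckpt_path="./logs/pretrained.ckpt")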
Question:
What is the best way for me to load a checkpoint (essentially warm-starting the model) and reset the optimizer using yaml config files and the Lightning CLI?
Attempts:
I reset the learning rate manually by creating a custom callback with the on_load_checkpoint hook:
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback

class OverwriteLR(Callback):
    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict) -> None:
        # Patch the optimizer state stored in the checkpoint before it is restored
        checkpoint['optimizer_states'][0]['param_groups'][0]['lr'] = 1e-08
I’d prefer not to do this because the value is hardcoded. I’d prefer to tune or change all of my hyperparameters from the yaml file; however, the optimizer, scheduler, and learning rate in my yaml configuration file are not exposed to this callback, and I couldn’t quite figure out how to expose them.
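For illustration only (untested, and the names are my own), what I’m imagining is a callback that takes the new learning rate as an init argument, so it could be configured from the yaml like any other callback:

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback

class OverwriteLR(Callback):
    """Overwrite the optimizer learning rate stored in a checkpoint before it is restored."""

    def __init__(self, new_lr: float) -> None:
        self.new_lr = new_lr  # hypothetical init arg, set from the yaml instead of hardcoded

    def on_load_checkpoint(self, trainer: Trainer, pl_module: LightningModule, checkpoint: dict) -> None:
        for optimizer_state in checkpoint["optimizer_states"]:
            for param_group in optimizer_state["param_groups"]:
                param_group["lr"] = self.new_lr

The callback could then presumably be listed under trainer.callbacks in the yaml with its own class_path and init_args, but that still only patches the optimizer state and doesn’t touch the scheduler or the epoch counter, which is why I’m asking for the recommended approach.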
I’d be grateful for any ideas on this matter. Thanks!
Version: Lightning 2.0.2