Changing the Optimizer and lr_scheduler with a callback

Hi,
I am trying to reset my optimizer and lr_scheduler once every n epochs via a callback, and I found this outdated thread.
As in that thread, I have access to the model's initial settings for the optimizer and scheduler, as defined by the configure_optimizers method.
Now, I would like to add to my callback something like:

def reset_optimizers(self, trainer, pl_module):
    orig_opt_dict = pl_module.configure_optimizers()
    trainer.optimizers = [orig_opt_dict["optimizer"]]
    trainer.lr_schedulers = [orig_opt_dict["lr_scheduler"]]

but unfortunately the attribute trainer.lr_schedulers does not exist. I figured out that I have to change trainer.strategy.lr_scheduler_configs to change the scheduler. At least I believe that is what happens, since changing the optimizer (via the trainer.optimizers property and its setter) also just passes the argument on to the strategy.
In the mentioned thread, the call to

trainer.lr_schedulers = trainer._configure_schedulers(
    pt_dict["lr_scheduler"],
    monitor="train/loss",
    is_manual_optimization=False,
)

seems to do just that. Unfortunately, this method does not seem to exist any more in my (relatively) current version of Lightning (2.0.9.post0).
Any ideas on how to accomplish the resetting?
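
For reference, the closest I have come so far looks roughly like this. It is an untested sketch: I am only guessing that LRSchedulerConfig can be imported from lightning.pytorch.utilities.types and constructed with just the scheduler, so please correct me if that is wrong.

from lightning.pytorch.utilities.types import LRSchedulerConfig  # guessed import path for 2.x

def reset_optimizers(self, trainer, pl_module):
    orig_opt_dict = pl_module.configure_optimizers()
    # the trainer.optimizers setter just forwards the list to the strategy
    trainer.optimizers = [orig_opt_dict["optimizer"]]
    # trainer.lr_schedulers no longer exists in 2.x, so write the strategy attribute directly
    trainer.strategy.lr_scheduler_configs = [
        LRSchedulerConfig(scheduler=orig_opt_dict["lr_scheduler"])
    ]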
Thanks in advance!

Hi hannesstagge,

I have a similar use case and had to reset the optimizer and scheduler in a hyperparameter-optimization context. My solution worked in Lightning 2.0.2 and 2.2.0.post0. I would set up your scheduler from within your model (a subclass of LightningModule). Then all you need to do is call MyModel.configure_optimizers() at whichever point you need to reset the optimizer and scheduler in one action. Here is an outline of how this function looks in my code:

import torch
from torch.optim import AdamW
import lightning as L


class MyModel(L.LightningModule):
    """ Implements the training, validation and testing routines """

    def __init__(self,  # your parameters to the module
                 **kwargs):
        super().__init__()
        # your init code and other functions

    def configure_optimizers(self):
        """ Create and configure the optimizer and learning rate schedule """
        # this method is called once by Lightning before training starts
        optimizer_params = self._get_optimizer_params()
        _lr = self.lr
        _wd = self.weight_decay
        optimizer = AdamW(optimizer_params, lr=_lr, weight_decay=_wd)

        # you may keep a reference to the optimizer so that it is easily accessible in your class (optional)
        self.optimizer = optimizer

        # Below is some logic that chooses the lr-scheduler. Adapt as needed.
        # In my case I have an option to disable the lr-scheduler if set to True.
        if self.disable_lr_scheduling:  # keep the lr constant
            scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
        else:
            _max_lr = self.lr
            scheduler = torch.optim.lr_scheduler.OneCycleLR(
                optimizer, max_lr=_max_lr, total_steps=self.trainer.estimated_stepping_batches,
                pct_start=self.lr_warmup,
            )

        # Another optional reference so I can get to the scheduler from MyModel
        self.scheduler = scheduler

        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
        }

You can call this function to reset the scheduler along with the optimizer.
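
As a rough sketch of that wiring from inside MyModel: it reuses the trainer.optimizers and trainer.strategy.lr_scheduler_configs attributes discussed in the question above, the LRSchedulerConfig import path is something you should verify for your version, and reset_every_n_epochs is just a hypothetical hyperparameter.

from lightning.pytorch.utilities.types import LRSchedulerConfig  # verify this path for your version

# a method of MyModel
def on_train_epoch_start(self):
    # reset_every_n_epochs is a hypothetical hyperparameter of MyModel
    if self.current_epoch > 0 and self.current_epoch % self.reset_every_n_epochs == 0:
        out = self.configure_optimizers()  # rebuilds the optimizer and scheduler in one action
        self.trainer.optimizers = [out["optimizer"]]
        self.trainer.strategy.lr_scheduler_configs = [
            LRSchedulerConfig(scheduler=out["lr_scheduler"]["scheduler"], interval="step")
        ]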

Note that, depending on your case, the epoch/global_step in your trainer and the state of the optimizer/scheduler may also need to be reset. If you need to reset the scheduler (and optimizer) at a particular point of your training/fine-tuning trajectory, you can run a dummy loop from zero up to the current global_step, recreating the progression of the scheduler/optimizer to the desired state. This is a rather naive way of doing it, and there may be simpler ways depending on the scheduler and optimizer used; someone with more insight may want to weigh in. I have actually posted a question on how to set the global_step and epoch via a callback.
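
For the scheduler part, that naive replay could look roughly like this (sketch only; new_scheduler stands for the scheduler freshly returned by configure_optimizers, and it assumes a step-interval scheduler such as OneCycleLR):

# fast-forward the freshly created scheduler to the trainer's current position
for _ in range(trainer.global_step):
    new_scheduler.step()  # each call advances the schedule by one training step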

Hope this helps
Best