How to save new lr hyperparameter after using LRFinder when using wandb

I am using wandb as my logger and I am implementing the lr finder with Lightning 2.0.4. When I check the hyperparameters saved in wandb after running my code, the logged lr is not the one selected by the lr finder. How should I modify my code so that the new lr is saved by wandb? I am trying to avoid calling wandb.init() directly because I am using DDP.

Here is my main function for reference:

import matplotlib.pyplot as plt
import yaml
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import torch
from datetime import datetime as dt
import wandb
from pprint import pprint
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint
import sys

# My code
from base.trainer import LitModel
from base.data import PointCloudDataModule


def main(cfg):
    # Setting all the random seeds to the same value.
    # This is important in a distributed training setting.
    # Each rank will get its own set of initial weights.
    # If they don't match up, the gradients will not match either,
    # leading to training that may not converge.
    pl.seed_everything(1)

    # Load data
    data_module = PointCloudDataModule(data_cfg=cfg['data'], batch_size=cfg['hp']['batch_size'])

    # Wrap in lightning module
    model = LitModel(
        lr=cfg['hp']['lr'],
        batch_size=cfg['hp']['batch_size'],
        dropout=cfg['hp']['dropout'],
        loss_function_type=cfg['hp']['loss_function_type'],
        weight_decay=cfg['hp']['weight_decay'],
        cawr_t_0=cfg['hp']['cawr_t_0'],
        cawr_t_mult=cfg['hp']['cawr_t_mult'],
        ocnn_stages=cfg['model']['ocnn_stages'],
        ocnn_use_additional_features=cfg['data']['ocnn_use_feats'],
        use_normals=cfg['data']['use_normals'],
        data_dir=cfg['data']['data_dir']
    )

    # Explicitly specify the process group backend if you choose to
    ddp = DDPStrategy(process_group_backend="gloo")

    # Set matrix multiplication precision https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
    torch.set_float32_matmul_precision('medium')

    # Set empty callback list
    callback_list = list()

    # Set up wandb logger
    current_dt = dt.now().strftime("%Y_%m_%d_%H_%M_%S")
    if cfg['hp']['num_epochs'] > 5:
        # More info about logging: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
        # Best practices for wandb: https://wandb.ai/wandb/pytorch-lightning-e2e/reports/W-B-Best-Practices-Guide--VmlldzozNTU1ODY1
        with open(cfg['misc']['wandb_key']) as f:
            wandb_key = f.readline().strip()  # strip the trailing newline from the key file

        # Get datetime for labelling run
        wandb.login(key=wandb_key)
        run_name = f"bs-{cfg['hp']['batch_size']}_lr-{cfg['hp']['lr']}_dt-{current_dt}"
        wandb_logger = WandbLogger(name=run_name,
                                   version=run_name,
                                   id=run_name,
                                   log_model="all",
                                   project="RQ2_pc_dl_data_fusion",
                                   offline=False,
                                   experiment=None,
                                   checkpoint_name=None)

        # Set up checkpointing
        ckpt_path = f"checkpoints/run_{current_dt}"
        cfg['hp']['ckpt_path'] = ckpt_path
        checkpoint_callback = ModelCheckpoint(monitor="val_loss",
                                              mode="min",
                                              auto_insert_metric_name=True,
                                              dirpath=ckpt_path
                                              )
        callback_list.append(checkpoint_callback)

        # Set up lr logging
        callback_list.append(LearningRateMonitor(logging_interval='step'))
    else:
        wandb_logger = False

    # Add custom progress bar to callbacks list
    class MyProgressBar(TQDMProgressBar):
        def init_validation_tqdm(self):
            bar = super().init_validation_tqdm()
            if not sys.stdout.isatty():
                bar.disable = True
            return bar

        def init_predict_tqdm(self):
            bar = super().init_predict_tqdm()
            if not sys.stdout.isatty():
                bar.disable = True
            return bar

        def init_test_tqdm(self):
            bar = super().init_test_tqdm()
            if not sys.stdout.isatty():
                bar.disable = True
            return bar

    callback_list.append(MyProgressBar())

    # Add early stopping to callback list
    callback_list.append(EarlyStopping(monitor="val_loss", mode="min", min_delta=0.1, patience=10))

    # Set callback list to None if it contains zero elements
    if len(callback_list) == 0:
        callback_list = None

    # Set up trainer
    # Ref docs for lightning https://lightning.ai/docs/pytorch/stable/common/trainer.html
    trainer = pl.Trainer(

        # * * * * Parallelization and resources
        strategy=ddp,
        accelerator="gpu",
        benchmark=True,  # torch.backends.cudnn.benchmark
        # Implementing FP16 Mixed Precision Training https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html
        precision="16-mixed",

        # * * * * Duration Params
        max_epochs=cfg['hp']['num_epochs'],
        max_time=None,
        limit_train_batches=cfg['data']['partition_train'],

        # * * * * Logging
        enable_progress_bar=True,
        logger=wandb_logger,
        default_root_dir="lr_finder_logs",

        # * * * * Debugging
        fast_dev_run=False,  # Run 1 batch for debugging
        profiler=None,  # Profiles the training loop to identify bottlenecks; set to None to turn off
        detect_anomaly=False,  # Enable anomaly detection for the autograd engine

        # * * * * Callbacks
        callbacks=callback_list,
    )

    print(f"Config:\n")
    pprint(cfg)

    if cfg['hp']['lr_finder']:
        # Implement auto lr finder
        tuner = Tuner(trainer)
        lr_finder = tuner.lr_find(model,
                                  min_lr=1e-5,
                                  max_lr=1,
                                  num_training=15,
                                  datamodule=data_module,
                                  update_attr=True,
                                  )
        new_lr = lr_finder.suggestion()
        print(f"LR finder identified:\n{new_lr}\nas candidate learning rate")
        fig = lr_finder.plot(suggest=True)
        plot_fpath = "temp/lr_finder_plot.jpg"
        plt.savefig(plot_fpath)
        fig.show()

        # Update the lr hyperparameter in the wandb config
        wandb.config["lr"] = new_lr
        wandb.config.update()

        # Save plot as an artifact in wandb
        wandb_logger.log_image(key="lr_finder_plot", images=[plot_fpath])

    # Implement training
    trainer.fit(model, datamodule=data_module)

    # Test model

    # TODO: implement testing as described here: https://github.com/Lightning-AI/lightning/issues/8375


if __name__ == "__main__":
    # Read config
    with open("config.yaml", "r") as yamlfile:
        cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)

    main(cfg=cfg)

Hey @harryseely
This is a known limitation and was already reported: `save_hyperparameters` does not save the selected `learning_rate` when `auto_lr_find` is used · Issue #15928 · Lightning-AI/lightning · GitHub

The only workaround I can think of is to manually update the learning rate in your wandb config via

wandb.config.update({"lr": self.learning_rate})

Yes this works great, thank you @awaelchli!
