I am using wandb as my logger and I am implementing the lr finder with lighting 2.0.4. When I check the saved hyperparameters in wandb after running my code, the saved lr is not the one that was selected by the lr finder. How can should I modify my code so that the new lr is saved by wandb? I am trying to avoid using the wandb.init() because I am using ddp.
Here is my main function for reference:
import matplotlib.pyplot as plt
import yaml
import lightning.pytorch as pl
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.strategies import DDPStrategy
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
import torch
from datetime import datetime as dt
import wandb
from pprint import pprint
from lightning.pytorch.callbacks import TQDMProgressBar, ModelCheckpoint
import sys
# My code
from base.trainer import LitModel
from base.data import PointCloudDataModule
def main(cfg):
# Setting all the random seeds to the same value.
# This is important in a distributed training setting.
# Each rank will get its own set of initial weights.
# If they don't match up, the gradients will not match either,
# leading to training that may not converge.
pl.seed_everything(1)
# Load data
data_module = PointCloudDataModule(data_cfg=cfg['data'], batch_size=cfg['hp']['batch_size'])
# Wrap in lightning module
model = LitModel(
lr=cfg['hp']['lr'],
batch_size=cfg['hp']['batch_size'],
dropout=cfg['hp']['dropout'],
loss_function_type=cfg['hp']['loss_function_type'],
weight_decay=cfg['hp']['weight_decay'],
cawr_t_0=cfg['hp']['cawr_t_0'],
cawr_t_mult=cfg['hp']['cawr_t_mult'],
ocnn_stages=cfg['model']['ocnn_stages'],
ocnn_use_additional_features=cfg['data']['ocnn_use_feats'],
use_normals=cfg['data']['use_normals'],
data_dir=cfg['data']['data_dir']
)
# Explicitly specify the process group backend if you choose to
ddp = DDPStrategy(process_group_backend="gloo")
# Set matrix multiplication precision https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html
torch.set_float32_matmul_precision('medium')
# Set empty callback list
callback_list = list()
# Set up wandb logger
current_dt = dt.now().strftime("%Y_%m_%d_%H_%M_%S")
if cfg['hp']['num_epochs'] > 5:
# More info about logging: https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.loggers.wandb.html
# Best practices for wandb: https://wandb.ai/wandb/pytorch-lightning-e2e/reports/W-B-Best-Practices-Guide--VmlldzozNTU1ODY1
with open(cfg['misc']['wandb_key']) as f:
wandb_key = f.readlines()[0]
# Get datetime for labelling run
wandb.login(key=wandb_key)
run_name = f"bs-{cfg['hp']['batch_size']}_lr-{cfg['hp']['lr']}_dt-{current_dt}"
wandb_logger = WandbLogger(name=run_name,
version=run_name,
id=run_name,
log_model="all",
project="RQ2_pc_dl_data_fusion",
offline=False,
experiment=None,
checkpoint_name=None)
# Set up checkpointing
ckpt_path = f"checkpoints/run_{current_dt}"
cfg['hp']['ckpt_path'] = ckpt_path
checkpoint_callback = ModelCheckpoint(monitor="val_loss",
mode="min",
auto_insert_metric_name=True,
dirpath=ckpt_path
)
callback_list.append(checkpoint_callback)
# Set up lr logging
callback_list.append(LearningRateMonitor(logging_interval='step'))
else:
wandb_logger = False
# Add custom progress bar to callbacks list
class MyProgressBar(TQDMProgressBar):
def init_validation_tqdm(self):
bar = super().init_validation_tqdm()
if not sys.stdout.isatty():
bar.disable = True
return bar
def init_predict_tqdm(self):
bar = super().init_predict_tqdm()
if not sys.stdout.isatty():
bar.disable = True
return bar
def init_test_tqdm(self):
bar = super().init_test_tqdm()
if not sys.stdout.isatty():
bar.disable = True
return bar
callback_list.append(MyProgressBar())
#Add early stopping to callback list
callback_list.append(EarlyStopping(monitor="val_loss", mode="min", min_delta=0.1, patience=10))
# Set callback list to None if it contains zero elements
if len(callback_list) == 0:
callback_list = None
# Set up trainer
# Ref docs for lightning https://lightning.ai/docs/pytorch/stable/common/trainer.html
trainer = pl.Trainer(
# * * * * Parallelization and resources
strategy=ddp,
accelerator="gpu",
benchmark=True, # torch.backends.cudnn.benchmark
# Implementing FP16 Mixed Precision Training https://lightning.ai/docs/pytorch/stable/common/precision_intermediate.html
precision="16-mixed",
# * * * * Duration Params
max_epochs=cfg['hp']['num_epochs'],
max_time=None,
limit_train_batches=cfg['data']['partition_train'],
# * * * * Logging
enable_progress_bar=True,
logger=wandb_logger,
default_root_dir="lr_finder_logs",
# * * * * Debugging
fast_dev_run=False, # Run 1 batch for debugging
profiler=None, # Profiles the train epoch base to identify bottlenecks set to None to turn off
detect_anomaly=False, # Enable anomaly detection for the autograd engine
# * * * * Callbacks
callbacks=callback_list,
)
print(f"Config:\n")
pprint(cfg)
if cfg['hp']['lr_finder']:
# Implement auto lr finder
tuner = Tuner(trainer)
lr_finder = tuner.lr_find(model,
min_lr=1e-5,
max_lr=1,
num_training=15,
datamodule=data_module,
update_attr=True,
)
new_lr = lr_finder.suggestion()
print(f"LR finder identified:\n{new_lr}\nas candidate learning rate")
fig = lr_finder.plot(suggest=True)
plot_fpath = "temp/lr_finder_plot.jpg"
plt.savefig(plot_fpath)
fig.show()
#Update lr hp in wandb log
wandb.config["lr"] = new_lr
wandb.config.update()
# Save plot as a artifact in wandb
wandb_logger.log_image(key="lr_finder_plot", images=[plot_fpath])
# Implement training
trainer.fit(model, datamodule=data_module)
# Test model
# TODO: implement testing as described here: https://github.com/Lightning-AI/lightning/issues/8375
if __name__ == "__main__":
# Read config
with open("config.yaml", "r") as yamlfile:
cfg = yaml.load(yamlfile, Loader=yaml.FullLoader)
main(cfg=cfg)