Hi there,
I’m trying to get a basic example of the LR finder working on MNIST so that I can then use it with more complicated models. Unfortunately, the plot of loss against LR (log-scale x-axis) looks wrong: the loss increases sharply even at very small learning rates (e.g. 10e-8).
I’ve created a script to run the LR finder, and another to plot the results.
LR Finder
```python
import json
import os

from pytorch_lightning import LightningModule, Trainer
import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")
BATCH_SIZE = 256 if torch.cuda.is_available() else 64


class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Learning rate passed to Adam in configure_optimizers().
        self.lr = 1e-3
        # Single linear layer: flattened 28x28 image -> 10 class scores.
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, _):
        x, y = batch
        loss = F.cross_entropy(self(x), y)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)


model = MNISTModel()
train_ds = MNIST(PATH_DATASETS, train=True, download=True, transform=ToTensor())
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE)

trainer = Trainer(
    accelerator='gpu',
    devices=1,
    precision=16)

# Run the LR range test; lr_finder.results is a dict with 'lr' and 'loss' lists.
lr_finder = trainer.tuner.lr_find(model, train_dataloaders=train_loader)

# Save the results so they can be plotted by a separate script.
filename = 'results.json'
with open(filename, 'w') as f:
    f.write(json.dumps(lr_finder.results))
```
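For reference when reading the plot: as far as I know, `lr_find` in 1.8.x defaults to `mode='exponential'`, `min_lr=1e-8`, `max_lr=1` and `num_training=100`, i.e. it sweeps 8 decades of learning rate in 100 steps. The sketch below is a plain-Python approximation of such a log-spaced schedule, not Lightning's own code; `exponential_lr_schedule` is a hypothetical helper, and the exact endpoint handling inside Lightning may differ slightly.

```python
import math


def exponential_lr_schedule(min_lr=1e-8, max_lr=1.0, num_steps=100):
    """Return learning rates evenly spaced in log space from min_lr to max_lr.

    Assumption: approximates an exponential LR range test; Lightning's
    internal scheduler may place the last point slightly differently.
    """
    if num_steps < 2:
        raise ValueError("num_steps must be at least 2")
    ratio = max_lr / min_lr
    # Each step multiplies the LR by the same constant factor.
    return [min_lr * ratio ** (i / (num_steps - 1)) for i in range(num_steps)]


lrs = exponential_lr_schedule()
# How many optimizer steps are spent per factor-of-10 increase in LR.
decades = math.log10(lrs[-1] / lrs[0])
print(f"{len(lrs)} steps over {decades:.0f} decades "
      f"= {len(lrs) / decades:.1f} steps per decade")
print("first LRs:", lrs[:3])
```

With these defaults only about a dozen batches fall in each decade, so the "very small LR" region of the plot corresponds to the first handful of optimizer steps on a freshly initialised model.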
Plotting
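```python
import json
import os

import matplotlib.pyplot as plt

filepath = os.path.join(os.environ['MYMI_CODE'], 'results.json')
with open(filepath) as f:
    results = json.load(f)

# Loss against learning rate, with the LR axis on a log scale.
plt.plot(results['lr'], results['loss'])
plt.xscale('log')
plt.xlabel('Learning rate')
plt.ylabel('Loss')
plt.show()
```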
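As far as I understand, the `loss` values stored in `lr_finder.results` are not raw batch losses but an exponentially smoothed version, and the search stops early once the smoothed loss exceeds `early_stop_threshold` (4.0 by default) times the best loss seen so far. The sketch below approximates both steps in plain Python so the plotted curve can be reasoned about; `smooth_losses` and `diverged` are hypothetical helpers, and `beta=0.98` is an assumption about Lightning's internals, not something confirmed here.

```python
import math


def smooth_losses(raw_losses, beta=0.98):
    """Exponentially weighted moving average with bias correction.

    Assumption: beta=0.98 is illustrative of what an LR finder might use.
    """
    avg = 0.0
    smoothed = []
    for step, loss in enumerate(raw_losses):
        avg = beta * avg + (1 - beta) * loss
        # Divide out the bias toward the 0.0 starting value of the average.
        smoothed.append(avg / (1 - beta ** (step + 1)))
    return smoothed


def diverged(smoothed, threshold=4.0):
    """Return the first step whose smoothed loss exceeds threshold * best loss
    seen so far (ignoring the first two steps), or None if it never does."""
    best = math.inf
    for step, loss in enumerate(smoothed):
        if step > 1 and loss > threshold * best:
            return step
        best = min(best, loss)
    return None


print(smooth_losses([2.3, 2.3, 2.3]))
print(smooth_losses([1.0, 1.0, 10.0]))
print(diverged([2.3, 2.2, 2.1, 9.0]))
```

A constant raw loss gives a flat smoothed curve, while a single spike is damped but still lifts the following points, so the plotted curve lags behind the raw batch losses.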
Results
(Plot of loss against learning rate on a log-scale x-axis; the image is not included here.)
Environment
```
Collecting environment information...
PyTorch version: 1.13.1+cu117
Is debug build: False
CUDA used to build PyTorch: 11.7
ROCM used to build PyTorch: N/A
OS: Red Hat Enterprise Linux Server release 7.9 (Maipo) (x86_64)
GCC version: (GCC) 10.2.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.17
Python version: 3.8.6 (default, Mar 29 2021, 14:28:48) [GCC 10.2.0] (64-bit runtime)
Python platform: Linux-3.10.0-1160.66.1.el7.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA A100 80GB PCIe
Nvidia driver version: 515.65.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.24.0
[pip3] pytorch-lightning==1.8.6
[pip3] torch==1.13.1
[pip3] torchaudio==0.13.1
[pip3] torchio==0.18.86
[pip3] torchmetrics==0.11.0
[pip3] torchvision==0.14.1
[conda] Could not collect
```
Any ideas on how to fix this script?
Thanks!
Brett