I am referring to the official documentation for manual optimization. To this end, I am trying to implement this using a custom learning-rate scheduler: linear warmup followed by step-decays. The code for this is:
# Get MNIST data-
path_to_files = "/home/amajumdar/Downloads/.data/"
batch_size = 512
train_dataset, test_dataset, train_loader, test_loader = mnist_dataset(path_to_files, batch_size = batch_size)
"""
Train model with Custom earning-rate Scheduler
Training dataset = 60000, batch size = 512, number of training steps/iterations per epoch = 60000 / 512 = 117.1875 = 117
After an initial linear learning rate warmup of 13 epochs or 1523 (13 * 117.1875 = 1523.4375) training steps:
1. For the next 7 epochs, or, until 20th epoch (20 * 117.1875 = 2343.75), use lr = 0.1.
2. For the next 5 epochs, or, until 25th epoch (25 * 117.1875 = 2929.6875), use lr = 0.01.
3. For remaining epochs, use lr = 0.001.
"""
boundaries = [2344, 2930]
values = [0.1, 0.01, 0.001]
def decay_function(
step, boundaries = [2344, 2930],
values = [0.1, 0.01, 0.001]
):
for idx, bound in enumerate(boundaries):
if step < bound:
return values[idx]
return values[-1]
class schedule():
def __init__(self, initial_learning_rate = 0.1, warmup_steps = 1000, decay_func = None):
self.initial_learning_rate = initial_learning_rate
self.warmup_steps = warmup_steps
self.decay_func = decay_func
self.warmup_step_size = initial_learning_rate/warmup_steps
self.current_lr = 0
def get_lr(self, step):
if step == 0:
return self.current_lr
elif step <= self.warmup_steps:
self.current_lr+= self.warmup_step_size
return self.current_lr
elif step > self.warmup_steps:
if self.decay_func:
return self.decay_func(step)
else:
return self.current_lr
# Initial linear LR warmup: 13 x 117.1875 = 1523.4375 in 13 epochs.
custom_lr_scheduler = schedule(
initial_learning_rate = 0.1, warmup_steps = 1523,
decay_func = decay_function
)
step = 0
# Define LightningModule-
class LeNet5_MNIST(pl.LightningModule):
def __init__(self, beta = 1.0):
super().__init__()
# Initialize an instance of LeNet-5 CNN architecture-
self.model = LeNet5(beta = beta)
# Apply weights initialization-
self.model.apply(init_weights)
# Important: This property activates manual optimization-
self.automatic_optimization = False
def compute_loss(self, batch):
x, y = batch
pred = self.model(x)
loss = F.cross_entropy(pred, y)
return loss
def validation_step(self, batch, batch_idx):
# Validation loop.
x_t, y_t = batch
out_t = self.model(x_t)
loss_t = F.cross_entropy(out_t, y_t)
running_corrects = 0.0
_, predicted_t = torch.max(out_t, 1)
running_corrects = torch.sum(predicted_t == y_t.data)
val_acc = (running_corrects.double() / len(y_t)) * 100
self.log('val_loss', loss_t)
self.log('val_acc', val_acc)
return {'loss_val': loss_t, 'val_acc': val_acc}
def training_step(self, batch, batch_idx):
'''
TIP:
Be careful where you call 'optimizer.zero_grad()', or your model won’t converge. It is
good practice to call 'optimizer.zero_grad()' before 'self.manual_backward(loss)'.
'''
opt = self.optimizers()
opt = opt.optimizer
opt.zero_grad()
# training_step() defines the training loop.
# It's independent of forward().
x, y = batch
pred = self.model(x)
loss = F.cross_entropy(pred, y)
# loss = self.compute_loss(batch)
self.manual_backward(loss)
opt.step()
# Use custom learning-rate scheduler-
global step
opt.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
step += 1
running_corrects = 0.0
_, predicted = torch.max(pred, 1)
running_corrects = torch.sum(predicted == y.data)
train_acc = (running_corrects.double() / len(y)) * 100
# log to Tensorboard (if installed) by default-
self.log('train_loss', loss, on_step = False, on_epoch = True)
self.log('train_acc', train_acc, on_step = False, on_epoch = True)
return {'loss': loss, 'train_acc': train_acc}
def on_after_backward(self):
# example to inspect gradient information in tensorboard-
# don't make the tf file huge
if self.trainer.global_step % 25 == 0:
for layer_name, param in self.named_parameters():
grad = param.grad
self.logger.experiment.add_histogram(
tag = layer_name, values = grad,
global_step = self.trainer.global_step
)
def configure_optimizers(self):
# optimizer = optim.Adam(params = self.parameters(), lr = 1e-3)
optimizer = torch.optim.SGD(
params = self.parameters(), lr = 0.0,
momentum = 0.9, weight_decay = 5e-4
)
return optimizer
model_cnn = LeNet5_MNIST(beta = 1.0)
# Checkpointing is enabled by default to the current working directory. To change the checkpoint
# path pass in-
path_to_ckpt = "/home/amajumdar/Documents/Codes/PyTorch_Lightning/checkpoints/"
# To modify the behavior of checkpointing pass in your own callback-
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
dirpath = os.getcwd(),
# filename = f'LeNet5_{epoch}-{val_acc:.2f}',
filename = 'LeNet5-mnist-{epoch:02d}-{val_acc:.2f}',
save_top_k = 1, verbose = True,
monitor = 'val_acc', mode = 'max',
# save_weights_only (bool) – if True, then only the model’s weights will be saved.
# Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
save_weights_only = False
)
# Learning-rate monitoring-
lr_monitor = LearningRateMonitor(logging_interval = 'step')
# trainer = Trainer(callbacks = [lr_monitor])
# Train the model-
trainer = pl.Trainer(
accelerator = 'cpu',
limit_train_batches = 1.0, limit_val_batches = 1.0,
max_epochs = 35, default_root_dir = path_to_ckpt,
callbacks = [checkpoint_callback, lr_monitor]
)
The manual optimization does not seem to save any checkpoint!! I found a similar open question here.
Help!