Manual optimization prevents saving checkpoint

I am following the official documentation for manual optimization and trying to combine it with a custom learning-rate scheduler: a linear warmup followed by step decays.
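
For context, the basic manual-optimization pattern from the docs that I am building on looks roughly like this (a simplified sketch; compute_loss just stands in for the actual loss computation):

class ManualOptSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False     # switch Lightning to manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()                 # optimizer returned by configure_optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)         # placeholder for the actual loss computation
        self.manual_backward(loss)              # instead of loss.backward()
        opt.step()

My full code is: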

# Standard imports (mnist_dataset, LeNet5 and init_weights are my own helpers defined elsewhere)-
import os
import torch
import torch.nn.functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor

# Get MNIST data-
path_to_files = "/home/amajumdar/Downloads/.data/"

batch_size = 512
train_dataset, test_dataset, train_loader, test_loader = mnist_dataset(path_to_files, batch_size = batch_size)


"""
Train model with Custom Learning-rate Scheduler

Training dataset size = 60000, batch size = 512, number of training steps/iterations per epoch = 60000 / 512 = 117.1875 ≈ 117

After an initial linear learning-rate warmup of 13 epochs, i.e. 1523 (13 * 117.1875 = 1523.4375) training steps:
1. For the next 7 epochs, i.e. until the 20th epoch (20 * 117.1875 = 2343.75 steps), use lr = 0.1.
2. For the next 5 epochs, i.e. until the 25th epoch (25 * 117.1875 = 2929.6875 steps), use lr = 0.01.
3. For the remaining epochs, use lr = 0.001.
"""
boundaries = [2344, 2930]
values = [0.1, 0.01, 0.001]


def decay_function(
    step, boundaries = [2344, 2930],
    values = [0.1, 0.01, 0.001]
):
    # Piecewise-constant decay: return the value paired with the first boundary the
    # current step has not yet reached; after the last boundary, use the final value.
    for idx, bound in enumerate(boundaries):
        if step < bound:
            return values[idx]

    return values[-1]
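
On its own, the decay function should return, for example:

decay_function(2000)   # 0.1   (before the first boundary at step 2344)
decay_function(2500)   # 0.01  (between the two boundaries)
decay_function(3500)   # 0.001 (after the last boundary at step 2930)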


class schedule():
    def __init__(self, initial_learning_rate = 0.1, warmup_steps = 1000, decay_func = None):
        self.initial_learning_rate = initial_learning_rate
        self.warmup_steps = warmup_steps
        self.decay_func = decay_func
        self.warmup_step_size = initial_learning_rate / warmup_steps
        self.current_lr = 0

    def get_lr(self, step):
        if step == 0:
            return self.current_lr
        elif step <= self.warmup_steps:
            # Linear warmup: increase the lr by a fixed amount on every call.
            self.current_lr += self.warmup_step_size
            return self.current_lr
        elif self.decay_func:
            # Past warmup: hand over to the step-decay function.
            return self.decay_func(step)
        else:
            # No decay function given: hold the lr reached at the end of warmup.
            return self.current_lr


# Initial linear LR warmup: 13 epochs x 117.1875 steps/epoch = 1523.4375 ≈ 1523 steps.
custom_lr_scheduler = schedule(
    initial_learning_rate = 0.1, warmup_steps = 1523,
    decay_func = decay_function
)
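
As a quick sanity check (on a throwaway instance, because get_lr() mutates its warmup counter on every call), driving the schedule sequentially should give roughly:

check_scheduler = schedule(
    initial_learning_rate = 0.1, warmup_steps = 1523,
    decay_func = decay_function
)
lrs = [check_scheduler.get_lr(s) for s in range(35 * 117 + 1)]   # ~35 epochs' worth of steps
# lrs[0]    -> 0       (no update at step 0)
# lrs[1523] -> ~0.1    (end of the linear warmup)
# lrs[2343] -> 0.1     (held until the first boundary)
# lrs[2344] -> 0.01    (after the first boundary)
# lrs[2930] -> 0.001   (after the second boundary)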

# Global step counter used to drive the custom scheduler from training_step()-
step = 0


# Define LightningModule-
class LeNet5_MNIST(pl.LightningModule):
    def __init__(self, beta = 1.0):
        super().__init__()

        # Initialize an instance of LeNet-5 CNN architecture-
        self.model = LeNet5(beta = beta)

        # Apply weights initialization-
        self.model.apply(init_weights)

        # Important: This property activates manual optimization-
        self.automatic_optimization = False

    
    def compute_loss(self, batch):
        x, y = batch
        pred = self.model(x)
        loss = F.cross_entropy(pred, y)
        return loss


    def validation_step(self, batch, batch_idx):
        # Validation loop.
        x_t, y_t = batch
        out_t = self.model(x_t)
        loss_t = F.cross_entropy(out_t, y_t)

        _, predicted_t = torch.max(out_t, 1)
        running_corrects = torch.sum(predicted_t == y_t.data)
        val_acc = (running_corrects.double() / len(y_t)) * 100

        self.log('val_loss', loss_t)
        self.log('val_acc', val_acc)
        return {'loss_val': loss_t, 'val_acc': val_acc}


    def training_step(self, batch, batch_idx):
        '''
        TIP:
        Be careful where you call 'optimizer.zero_grad()', or your model won’t converge. It is
        good practice to call 'optimizer.zero_grad()' before 'self.manual_backward(loss)'.
        '''
        opt = self.optimizers()
        # Unwrap the underlying torch optimizer from Lightning's 'LightningOptimizer' wrapper-
        opt = opt.optimizer
        opt.zero_grad()

        # training_step() defines the training loop.
        # It's independent of forward().
        x, y = batch
        pred = self.model(x)
        loss = F.cross_entropy(pred, y)
        # loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()

        # Use custom learning-rate scheduler-
        global step
        opt.param_groups[0]['lr'] = custom_lr_scheduler.get_lr(step)
        step += 1

        _, predicted = torch.max(pred, 1)
        running_corrects = torch.sum(predicted == y.data)
        train_acc = (running_corrects.double() / len(y)) * 100
        
        # log to TensorBoard (if installed) by default-
        self.log('train_loss', loss, on_step = False, on_epoch = True)
        self.log('train_acc', train_acc, on_step = False, on_epoch = True)

        return {'loss': loss, 'train_acc': train_acc}


    def on_after_backward(self):
        # Example: inspect gradient information in TensorBoard.
        # Log only every 25 steps so the tfevents file doesn't get huge.
        if self.trainer.global_step % 25 == 0:
            for layer_name, param in self.named_parameters():
                grad = param.grad
                self.logger.experiment.add_histogram(
                    tag = layer_name, values = grad,
                    global_step = self.trainer.global_step
                    )


    def configure_optimizers(self):
        # optimizer = optim.Adam(params = self.parameters(), lr = 1e-3)
        # lr starts at 0.0 here; the custom scheduler overwrites it every step in training_step()-
        optimizer = torch.optim.SGD(
            params = self.parameters(), lr = 0.0,
            momentum = 0.9, weight_decay = 5e-4
        )
        return optimizer




model_cnn = LeNet5_MNIST(beta = 1.0)


# Checkpointing is enabled by default and saves to the current working directory. To change the
# checkpoint path, pass it to the Trainer as 'default_root_dir'-
path_to_ckpt = "/home/amajumdar/Documents/Codes/PyTorch_Lightning/checkpoints/"

# To modify the behavior of checkpointing pass in your own callback-
# DEFAULTS used by the Trainer
checkpoint_callback = ModelCheckpoint(
    dirpath = os.getcwd(),
    # filename = f'LeNet5_{epoch}-{val_acc:.2f}',
    filename = 'LeNet5-mnist-{epoch:02d}-{val_acc:.2f}',
    save_top_k = 1, verbose = True,
    monitor = 'val_acc', mode = 'max',
    # save_weights_only (bool) – if True, then only the model’s weights will be saved.
    # Otherwise, the optimizer states, lr-scheduler states, etc are added in the checkpoint too.
    save_weights_only = False
)

# Learning-rate monitoring-
lr_monitor = LearningRateMonitor(logging_interval = 'step')
# trainer = Trainer(callbacks = [lr_monitor])


# Train the model-
trainer = pl.Trainer(
    accelerator = 'cpu',
    limit_train_batches = 1.0, limit_val_batches = 1.0,
    max_epochs = 35, default_root_dir = path_to_ckpt,
    callbacks = [checkpoint_callback, lr_monitor]
)

trainer.fit(model_cnn, train_loader, test_loader)

With manual optimization, no checkpoint gets saved at all! I found a similar open question here.

Help!