How to switch optimizers during training

I tried the following code on the MNIST dataset. It seems that the loss updates for a couple of steps and then stops changing. Could you check this code and, if possible, share the code that you tried on MNIST with the LBFGS optimizer?

import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader,random_split
import pytorch_lightning as pl 
from IPython.display import clear_output

class LightningMNISTClassifier(pl.LightningModule):
  def __init__(self):
    super(LightningMNISTClassifier,self).__init__()
    self.layer_1 = nn.Linear(28 * 28, 128)
    self.layer_2 = nn.Linear(128, 256)
    self.layer_3 = nn.Linear(256, 10)
    
  def forward(self, x):
    batch_size, channels, width, height = x.size()
    x=x.view(batch_size,-1)
    # layer 1
    x = self.layer_1(x)
    x = torch.relu(x)
    # layer 2
    x = self.layer_2(x)
    x = torch.relu(x) 
    # layer 3
    x = self.layer_3(x)
    # log-probability distribution over labels
    x = torch.log_softmax(x, dim=1)  
    return x 
  def prepare_data(self):
    transform=transforms.Compose([transforms.ToTensor(), 
                                  transforms.Normalize((0.1307,), (0.3081,))])
    # prepare transforms standard to MNIST
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
    mnist_test = MNIST(os.getcwd(), train=False, download=True, transform=transform)  
    self.mnist_train, self.mnist_val = random_split(mnist_train, [55000, 5000])

  def train_dataloader(self):
    return DataLoader(self.mnist_train,batch_size=1024)
 
  # def val_dataloader(self):
  #   return DataLoader(self.mnist_val,batch_size=1024)
  # def test_dataloader(self):
  #   return DataLoader(self.mnist_test,batch_size=1024)


  def configure_optimizers(self):
    # optimizer=optim.Adam(self.parameters(),lr=1e-3)
    optimizer = optim.LBFGS(self.parameters(), lr=1e-4)
    return optimizer

  # def backward(self, trainer, loss, optimizer):
  #   loss.backward(retain_graph=True)


  def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx,
                     second_order_closure, on_tpu=False, using_native_amp=False,
                     using_lbfgs=False):
    # update params; pass the closure so LBFGS can re-evaluate the loss
    optimizer.step(second_order_closure)

  def cross_entropy_loss(self,logits,labels):
    return F.nll_loss(logits,labels)

  def training_step(self,train_batch,batch_idx):
    x,y=train_batch
    logits=self.forward(x)
    loss=self.cross_entropy_loss(logits,y)
    return  {'loss':loss}

  def training_epoch_end(self,outputs):
    avg_loss=torch.stack([x['loss'] for x in outputs]).mean()
    print('epoch={}, avg_Train_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
    # return {'avg_train_loss':avg_loss}

  # def validation_step(self,val_batch,batch_idx):
  #   x,y=val_batch
  #   logits=self.forward(x)
  #   loss=self.cross_entropy_loss(logits,y)
  #   return {'val_loss':loss}
  # def validation_epoch_end(self,outputs):
  #   avg_loss=torch.stack([x['val_loss'] for x in outputs]).mean()
  #   print('epoch={}, avg_Test_loss={:.2f}'.format(self.current_epoch,avg_loss.item()))
  #   return {'avg_val_loss':avg_loss}

model=LightningMNISTClassifier()
#from pytorch_lightning.callbacks import EarlyStopping
trainer=pl.Trainer(max_epochs=400,gpus=1,
                  #  check_val_every_n_epoch=2,
                  #  accumulate_grad_batches=5,
#                   early_stop_callback=early_stop,
                  #  limit_train_batches=50,
#                   val_check_interval=0.25,
                   progress_bar_refresh_rate=0,
#                   num_sanity_val_steps=0,
                   weights_summary=None)
clear_output(wait=True)
trainer.fit(model)
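
For reference, what I am ultimately trying to do is train with Adam for a few epochs and then switch to LBFGS. Below is a minimal sketch of that in plain PyTorch, the way I understand it: once LBFGS takes over, optimizer.step needs a closure that zeroes the gradients, recomputes the loss and backpropagates, because LBFGS re-evaluates the loss several times per step. The model, loss_fn and train_loader here are toy stand-ins, not the MNIST model above.

import torch
import torch.nn as nn
import torch.optim as optim

# toy stand-ins so the sketch runs on its own; in my real code these are the
# MNIST network and DataLoader from the post above
model = nn.Sequential(nn.Linear(20, 64), nn.ReLU(), nn.Linear(64, 2))
loss_fn = nn.CrossEntropyLoss()
train_loader = [(torch.randn(32, 20), torch.randint(0, 2, (32,))) for _ in range(10)]

SWITCH_EPOCH = 5          # epochs trained with Adam before switching to LBFGS
TOTAL_EPOCHS = 20

optimizer = optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(TOTAL_EPOCHS):
    if epoch == SWITCH_EPOCH:
        # switch optimizers: build a new optimizer over the same parameters
        optimizer = optim.LBFGS(model.parameters(), lr=1e-4)

    for x, y in train_loader:
        if isinstance(optimizer, optim.LBFGS):
            # LBFGS evaluates the loss several times per step, so step() takes
            # a closure that zeroes grads, recomputes the loss and backprops
            def closure():
                optimizer.zero_grad()
                loss = loss_fn(model(x), y)
                loss.backward()
                return loss
            optimizer.step(closure)
        else:
            optimizer.zero_grad()
            loss = loss_fn(model(x), y)
            loss.backward()
            optimizer.step()

What I cannot figure out is how to express this switch cleanly through configure_optimizers / optimizer_step in Lightning, which is why I am asking here.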