How to switch optimizers during training

This code runs with Adam but not with LBFGS; we want to start training with Adam and then switch to LBFGS. We are trying to solve a PDE using deep learning, and the problem_formulation method of the code below contains the loss functions.
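For reference, LBFGS differs from Adam in that optimizer.step() must be given a closure that re-evaluates the loss; a minimal plain-PyTorch sketch of that pattern (model, compute_loss, x_batch, and t_batch are hypothetical placeholders, not names from the code below):

import torch.optim as optim

# Unlike Adam, LBFGS.step() takes a closure that recomputes the loss;
# LBFGS may call it several times per step.
optimizer = optim.LBFGS(model.parameters(), lr=0.01)

def closure():
    optimizer.zero_grad()
    loss = compute_loss(model, x_batch, t_batch)
    loss.backward()
    return loss

optimizer.step(closure)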

from IPython.display import clear_output
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader, Dataset
from torch.autograd import grad
import pytorch_lightning as pl
from argparse import Namespace


# u_t=u_xx
# u(t,0)=0
# u(t,1)=1
# u(0,x)=(2*x)/(1+x^2)

class MyData(Dataset):
    def __init__(self, startX, stopX, startT, stopT, NumberOfSteps):
        super(MyData, self).__init__()
        x = torch.linspace(startX, stopX, NumberOfSteps, dtype=torch.float32, requires_grad=True)
        t = torch.linspace(startT, stopT, NumberOfSteps, dtype=torch.float32, requires_grad=True)
        gx, gt = torch.meshgrid(x, t)
        self.x = gx.contiguous().view(-1, 1)
        self.t = gt.contiguous().view(-1, 1)

    def __len__(self):
        return self.x.shape[0]

    def __getitem__(self, index):
        return self.x[index], self.t[index]


class MyModel(pl.LightningModule):
    def __init__(self, hparams):
        super(MyModel, self).__init__()
        self.hparams = hparams
        self.Grid = None
        self.loss_fn = nn.MSELoss()
        self.bndry = torch.tensor([0., 1.], dtype=torch.float32, requires_grad=True).view(-1, 1).cuda()
        self.fc1 = nn.Linear(2, 20)
        self.fc2 = nn.Linear(20, 20)
        self.fc3 = nn.Linear(20, 20)
        self.fc4 = nn.Linear(20, 20)
        self.fc5 = nn.Linear(20, 20)
        self.fc6 = nn.Linear(20, 20)
        self.fc7 = nn.Linear(20, 1)
        self.T = nn.Tanh()

    def forward(self, X, T):
        x = torch.cat((X, T), dim=1)
        x = self.T(self.fc1(x))
        x = self.T(self.fc2(x))
        x = self.T(self.fc3(x))
        x = self.T(self.fc4(x))
        x = self.T(self.fc5(x))
        x = self.T(self.fc6(x))
        x = self.fc7(x)
        return x

    def prepare_data(self):
        self.Grid = MyData(0, 1, 0, 2, 200)

    def train_dataloader(self):
        return DataLoader(dataset=self.Grid, batch_size=self.hparams.batch_size, shuffle=True)


    def configure_optimizers(self):
        # training works when this instead returns e.g. optim.Adam(self.parameters(), lr=self.hparams.lr)
        optimizer = optim.LBFGS(self.parameters(), lr=0.01)
        return optimizer

    
    def training_step(self, train_batch, batch_idx):
        x, t = train_batch
        lg, lb, li = self.problem_formulation(x, t, self.bndry)
        loss = lg + lb + li
        return {'loss': loss}

    def backward(self, trainer, loss, optimizer, optimizer_idx):
        # retain the graph: the PDE losses are differentiated again when the
        # closure re-evaluates them during an LBFGS step
        loss.backward(retain_graph=True)

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure, on_tpu=False, using_native_amp=False, using_lbfgs=True):
        # LBFGS must be stepped with the closure so it can re-evaluate the loss
        optimizer.step(second_order_closure)


    def training_epoch_end(self, outputs):
        clear_output(wait=True)
        sum_total_loss = torch.stack([x['loss'] for x in outputs]).sum()
        print('Epoch={}, total_loss={:.3f}'.format(self.current_epoch, sum_total_loss.item()))
        # return {'sum_total_loss': sum_total_loss}

    def on_train_epoch_end(self):
        fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 8))
        self.plot_pcolor(0, 1, 0, 2, 100, self.device, fig, ax)
        self.Plot_InitialCondition(0, 1, 100, self.device, ax)

    def Plot_InitialCondition(self, xStart, xStop, NUM, device, ax):
        x = torch.linspace(xStart, xStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device).view(-1, 1)
        t = torch.zeros_like(x)
        out = self.forward(x, t)
        x = x.detach().cpu()
        out = out.detach().cpu()
        ax[1].plot(x, out)
        plt.ylim(0, 1.2)
        ax[1].axvline(x=1, color='r')
        ax[1].axhline(y=1, color='r')
        plt.show()

    def plot_pcolor(self, xStart, xStop, tStart, tStop, NUM, device, fig, ax):
        x = torch.linspace(xStart, xStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device)
        t = torch.linspace(tStart, tStop, NUM, dtype=torch.float32, requires_grad=True, device=self.device)
        xg, tg = torch.meshgrid(x, t)
        X = xg.clone()
        T = tg.clone()
        X = X.view(-1, 1)
        T = T.view(-1, 1)
        z = self.forward(X, T)
        z = z.view(NUM, -1)
        xg = xg.detach().cpu()
        tg = tg.detach().cpu()
        z = z.detach().cpu()
        j = ax[0].pcolormesh(xg, tg, z, cmap='jet')
        fig.colorbar(j, ax=ax[0])

    def problem_formulation(self, x, t, bndry):
        # PDE residual: enforce u_t = u_xx on the interior grid points
        u = self.forward(x, t)
        u_x = grad(u, x, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)[0]
        u_xx = grad(u_x, x, grad_outputs=torch.ones_like(u_x), create_graph=True, allow_unused=True)[0]
        u_t = grad(u, t, grad_outputs=torch.ones_like(u), create_graph=True, allow_unused=True)[0]
        lossGrid = self.loss_fn(u_xx, u_t)

        # boundary conditions: u(t, 0) = 0 and u(t, 1) = 1
        bndryExpand = torch.ones_like(t) * bndry[0]
        u0 = self.forward(bndryExpand, t)
        lossB0 = self.loss_fn(u0, torch.zeros_like(u0, dtype=torch.float32))
        bndryExpand = torch.ones_like(t) * bndry[1]
        u1 = self.forward(bndryExpand, t)
        lossB1 = self.loss_fn(u1, torch.ones_like(u1, dtype=torch.float32))

        # initial condition: u(0, x) = (2*x) / (1 + x^2)
        bndryExpand = torch.ones_like(x) * bndry[0]
        uInit = self.forward(x, bndryExpand)
        lossIni = self.loss_fn(uInit, (2 * x) / (1 + (x ** 2)))

        return lossGrid, lossB0 + lossB1, lossIni


def main(hparams):
    model = MyModel(hparams)
    # tb_logger = pl.loggers.TensorBoardLogger(hparams.TensorBoard_path,name=hparams.TensorBoard_FileName)
    trainer = pl.Trainer(fast_dev_run=hparams.fast_dev_run,
                         max_epochs=hparams.max_epochs,
                         gpus=hparams.gpus,
                         accumulate_grad_batches=hparams.accumulate_grad_batches,
                         # limit_train_batches=hparams.limit_train_batches,
                         progress_bar_refresh_rate=hparams.progress_bar_refresh_rate,
                         #                     logger=tb_logger,
                         #                     logger=False,
                         checkpoint_callback=hparams.checkpoint_callback)
    trainer.fit(model)


if __name__ == '__main__':
    args = {'root': '/content/drive/Shared drives/Poozesh/facades/train',
            'fast_dev_run': False,
            'max_epochs': 50,
            'gpus': 1,
            # 'limit_train_batches':4,
            'accumulate_grad_batches': 5,
            'progress_bar_refresh_rate': 0,
            'checkpoint_callback': False,
            'lr': 0.001,
            'batch_size': 1024,
            'num_workers': 1,
            }
    hparams = Namespace(**args)
    main(hparams)
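
For clarity, the behaviour we are after, written in plain PyTorch rather than Lightning, would look roughly like this (a sketch of the intent, not working code from this project; switch_epoch is a made-up threshold, and loader is assumed to yield (x, t) batches as in MyData):

# Sketch: switch from Adam to LBFGS partway through training.
model = MyModel(hparams)
optimizer = optim.Adam(model.parameters(), lr=hparams.lr)
for epoch in range(hparams.max_epochs):
    if epoch == switch_epoch:
        # rebuild the optimizer over the same parameters
        optimizer = optim.LBFGS(model.parameters(), lr=0.01)
    for x, t in loader:
        def closure():
            optimizer.zero_grad()
            lg, lb, li = model.problem_formulation(x, t, model.bndry)
            loss = lg + lb + li
            loss.backward()
            return loss
        # every torch optimizer accepts an optional closure,
        # so this line works for both Adam and LBFGS
        optimizer.step(closure)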