Manual Optimization and CycleGAN

Hello!

I am trying to implement a CycleGAN, but I have a problem with manual optimization and I am not sure what I am doing wrong. I have already read many posts and the GAN tutorial in PyTorch Lightning, but I still run into the same error every time.

Is there any way to approach this? I tried toggle_optimizer and untoggle_optimizer, and it does not work. I also tried wrapping things in with torch.no_grad(): … which did not work either. Finally, I read somewhere that it might be a version issue, but I already tried 2.0.0 and 2.0.1.
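
For context, this is roughly the update pattern I am trying to follow from the Lightning manual-optimization / basic GAN example [5] (a simplified sketch with placeholder loss helpers, not my real code):

  def training_step(self, batch, batch_idx):
    # self.automatic_optimization = False is set in __init__
    opt_g, opt_d = self.optimizers()  # the optimizers returned by configure_optimizers

    # Generator update
    self.toggle_optimizer(opt_g)
    g_loss = self.compute_generator_loss(batch)      # placeholder helper
    opt_g.zero_grad()
    self.manual_backward(g_loss)
    opt_g.step()
    self.untoggle_optimizer(opt_g)

    # Discriminator update
    self.toggle_optimizer(opt_d)
    d_loss = self.compute_discriminator_loss(batch)  # placeholder helper
    opt_d.zero_grad()
    self.manual_backward(d_loss)
    opt_d.step()
    self.untoggle_optimizer(opt_d)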

Here is the error:

Here is my code:

class CycleGAN(pl.LightningModule):
  '''
  CycleGAN Class: 
  @Based on the paper: 
   - [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
    "Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
    in IEEE International Conference on Computer Vision (ICCV), 2017
   - [2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2020).
    Generative adversarial networks. Communications of the ACM, 63(11), 139-144.
   - [3] https://arxiv.org/abs/1703.10593
   - [4] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019). 
    Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
    IEEE transactions on medical imaging, 38(10), 2375-2388.
   - [5] https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html
   - [6] https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
    

  @Description
  @Inputs
  @Outputs
  '''
  def __init__(self,
               input,
               params,
               features=64):
    super(CycleGAN,self).__init__()

    ############# Customize class #############
    self.save_hyperparameters(params)
    self.automatic_optimization = False
    #self.device="cuda"
    self.lr = params["lr"]   
    self.b1 = params["b1"]
    self.b2 = params["b2"]
    self.lbc_T1 = params["lbc_T1"]   
    self.lbc_T2 = params["lbc_T2"]
    self.btch_size = params["batch_size"]
    self.target_shape = params["target_shape"]
    self.lbi=params["lbi"]

    ############# Define components #############
    self.G_T1_T2=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
    self.D_T1=Discriminator(input,HCh=features,n=3)

    self.G_T2_T1=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
    self.D_T2=Discriminator(input,HCh=features,n=3)


    self.G_T1_T2=self.G_T1_T2.apply(self.weights_init)
    self.D_T1=self.D_T1.apply(self.weights_init)
    self.G_T2_T1=self.G_T2_T1.apply(self.weights_init)
    self.D_T2=self.D_T2.apply(self.weights_init)



   ############# Define loss #############
    self.identity_loss = torch.nn.L1Loss()
    self.adv_loss = torch.nn.MSELoss() #adversarial loss function to keep track of how well the GAN is fooling the discriminator and how well the discriminator is catching the GAN
    self.cycle_loss = torch.nn.L1Loss()
    

  def forward(self, x):
    x=self.G_T1_T2(x)
    return x

  def training_step(self, batch, batch_idx):
    '''
      @Description:
      - [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
      Compute the loss according to the original implementation of CycleGAN.
    

      @Inputs

      @Outputs
    '''
    ############# Initialization #############
    real_T1, real_T2 = batch
    Gopt,Dopt_T1,Dopt_T2=self.configure_optimizers()

    ############# update discriminator #############
    #### Discriminator T1
    self.toggle_optimizer(Dopt_T1)
   
    f_T1 = self.G_T2_T1(real_T2)
    T1Loss_Dics = self.DiscLoss(real_T1,f_T1,disc="T1")
    self.log("D_loss_T1", T1Loss_Dics, prog_bar=True)
    
    self.manual_backward(T1Loss_Dics,retain_graph=True)
    Dopt_T1.step()
    Dopt_T1.zero_grad() # Clear the gradients for the next iteration
    self.untoggle_optimizer(Dopt_T1)

    #### Discriminator T2
    self.toggle_optimizer(Dopt_T2)
    
    f_T2 = self.G_T1_T2(real_T1)
    T2Loss_Dics = self.DiscLoss(real_T2,f_T2,disc="T2")
    self.log("D_loss_T2", T2Loss_Dics, prog_bar=True)
    
    self.manual_backward(T2Loss_Dics,retain_graph=True)
    Dopt_T2.step()
    Dopt_T2.zero_grad() # Clear the gradients for the next iteration
    self.untoggle_optimizer(Dopt_T2)


    ############# update Generator #############
    self.toggle_optimizer(Gopt)
    gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term = self.GenLoss(real_T1, real_T2)
    self.log("G_loss", gen_loss, prog_bar=True)
    
    self.manual_backward(gen_loss) # Update gradients
    Gopt.step() # Update optimizer
    Gopt.zero_grad()
    self.untoggle_optimizer(Gopt)
    
    return {'G_loss': gen_loss, 'D_loss_T2': T2Loss_Dics, 'D_loss_T1': T1Loss_Dics, 'identity': Iden_term,'Cycle_term': Cycle_term, "Adver_term":Adv_term}


  def validation_step(self, batch, batch_idx):

    ############# Initialization #############
    real_T1, real_T2 = batch
    Gopt,Dopt_T1,Dopt_T2=self.configure_optimizers()

    ############# update discriminator #############
    #### Discriminator T1
    self.toggle_optimizer(Dopt_T1)
    f_T1 = self.G_T2_T1(real_T2)
    
    T1Loss_Dics = self.DiscLoss(real_T1,f_T1,disc="T1")
    Dopt_T1.zero_grad() # Zero out the gradient before backpropagation
    self.manual_backward(T1Loss_Dics,retain_graph=True)
    
    Dopt_T1.step()
    self.untoggle_optimizer(Dopt_T1)

    #### Discriminator T2
    self.toggle_optimizer(Dopt_T2)
    f_T2 = self.G_T1_T2(real_T1)

    T2Loss_Dics = self.DiscLoss(real_T2,f_T2,disc="T2")
    Dopt_T2.zero_grad() # Zero out the gradient before backpropagation
    self.manual_backward(T2Loss_Dics,retain_graph=True)
    
    Dopt_T2.step()
    self.untoggle_optimizer(Dopt_T2)


    ############# update Generator #############
    self.toggle_optimizer(Gopt)
    gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term = self.GenLoss(real_T1, real_T2)
    Gopt.zero_grad()
    self.manual_backward(gen_loss) # Update gradients
    Gopt.step() # Update optimizer
    self.untoggle_optimizer(Gopt)

    ########### Loggers ###########
    self.log("Dval_loss_T1", T1Loss_Dics, prog_bar=True,on_step=True, on_epoch=True, logger=True)
    self.log("Dval_loss_T2", T2Loss_Dics, prog_bar=True,on_step=True, on_epoch=True, logger=True)
    self.log("Gval_loss", gen_loss, prog_bar=True,on_step=True, on_epoch=True, logger=True)


    return {'Gval_loss': gen_loss, 'Dval_loss_T2': T2Loss_Dics, 'Dval_loss_T1': T1Loss_Dics, 'Val_identity': Iden_term,'Val_Cycle_term': Cycle_term, "Val_Adver_term":Adv_term}

  def configure_optimizers(self):
    '''
    @Description:
    - [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix

    Initialize the optimizers. As described in the original CycleGAN implementation [1],
    the optimizers are Adam.
  

    @Inputs

    @Outputs
    '''

    lr = self.lr
    b1 = self.b1
    b2 = self.b2

    #Gopt_T1_T2 = torch.optim.Adam(self.G_T1_T2.parameters(), lr=lr, betas=(b1, b2))
    Dopt_T1= torch.optim.Adam(self.D_T1.parameters(), lr=lr, betas=(b1, b2))

    #Gopt_T2_T1 = torch.optim.Adam(self.G_T2_T1.parameters(), lr=lr, betas=(b1, b2))
    Dopt_T2 = torch.optim.Adam(self.D_T2.parameters(), lr=lr, betas=(b1, b2))
    Gopt= torch.optim.Adam(list(self.G_T1_T2.parameters()) + list(self.G_T2_T1.parameters()), lr=lr, betas=(0.5, 0.999))

    return Gopt,Dopt_T1,Dopt_T2
  
  def DiscLoss(self,real,fake,disc="T1"):
    '''
    @Description
    This function computes the discriminator loss using the MSE adversarial loss function.
    Given the target labels and the discriminator predictions, it returns the adversarial loss.
    From the adversarial losses on the real and the fake image, the discriminator loss is computed as:

    discriminator_loss = (adv_fake + adv_real) / 2

    @Inputs
    real. Tensor, real image.
    fake. Tensor, fake image.
    
    @Outputs
    Discriminator loss

    '''

    if disc == "T1":
      disc_fake_hat = self.D_T1(fake.detach())      
      disc_real_hat = self.D_T1(real)
    else:
      disc_fake_hat = self.D_T2(fake.detach())
      disc_real_hat = self.D_T2(real)

    fake_loss = self.adv_loss(disc_fake_hat, torch.zeros_like(disc_fake_hat))
    real_loss = self.adv_loss(disc_real_hat, torch.ones_like(disc_real_hat))

    r=(fake_loss + real_loss) / 2
    return r

  def GenLoss(self, real_T1, real_T2):
    '''
    @Description
    Computes the total generator objective: the adversarial term (both discriminators should
    classify the fakes as real), the cycle-consistency term weighted by lbc_T1/lbc_T2, and the
    identity term weighted by lbi.
    @Inputs
    real_T1, real_T2. Tensors, real images from each domain.
    @Outputs
    gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term
    '''

    #compute fakes
    f_T1 = self.G_T2_T1(real_T2)
    f_T2 = self.G_T1_T2(real_T2)

    #Compute Discriminators output
    dic_f_T1_hat = self.D_T1(f_T1)
    dic_f_T2_hat = self.D_T2(f_T2)

    # Compute adversarial loss AdvLoss_T2_T1 +  AdvLoss_T1_T2
    Adv_term=self.adv_loss(dic_f_T1_hat, torch.ones_like(dic_f_T1_hat)) + self.adv_loss(dic_f_T2_hat, torch.ones_like(dic_f_T2_hat))

    # Compute Cycles
    C_T1 = self.G_T2_T1(f_T2)
    C_T2 = self.G_T1_T2(f_T1)

    # Compute cycle consistency.
    Cycle_term=self.lbc_T1*self.cycle_loss(C_T1,real_T1)+self.lbc_T2*self.cycle_loss(C_T2,real_T2)
        
    #Compute Identities
    identity_T1 = self.G_T2_T1(real_T1)
    identity_T2 = self.G_T1_T2(real_T2)

    # Compute Identity term
    Iden_term =  self.identity_loss (identity_T1, real_T1) + self.identity_loss (identity_T2, real_T2)

    # Compute Total loss
    gen_loss = self.lbi * Iden_term +  Cycle_term + Adv_term


    return gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term

  def weights_init(self,m):
    if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
    if isinstance(m, nn.BatchNorm2d):
        torch.nn.init.normal_(m.weight, 0.0, 0.02)
        torch.nn.init.constant_(m.bias, 0)

I found the reason. I was calling self.manual_backward(T1Loss_Dics, retain_graph=True)
in validation_step. There is no need to call manual_backward or torch.no_grad() in validation_step; Lightning handles that for you.
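
In case it helps anyone else, the validation_step then reduces to roughly this (same loss helpers as in the code above, only computing and logging the losses, with no optimizer or backward calls):

  def validation_step(self, batch, batch_idx):
    real_T1, real_T2 = batch

    # Only evaluate the losses; Lightning already runs validation without gradients
    f_T1 = self.G_T2_T1(real_T2)
    f_T2 = self.G_T1_T2(real_T1)
    T1Loss_Dics = self.DiscLoss(real_T1, f_T1, disc="T1")
    T2Loss_Dics = self.DiscLoss(real_T2, f_T2, disc="T2")
    gen_loss, f_T1, f_T2, Iden_term, Cycle_term, Adv_term = self.GenLoss(real_T1, real_T2)

    self.log("Dval_loss_T1", T1Loss_Dics, prog_bar=True, on_step=True, on_epoch=True, logger=True)
    self.log("Dval_loss_T2", T2Loss_Dics, prog_bar=True, on_step=True, on_epoch=True, logger=True)
    self.log("Gval_loss", gen_loss, prog_bar=True, on_step=True, on_epoch=True, logger=True)

    return {'Gval_loss': gen_loss, 'Dval_loss_T2': T2Loss_Dics, 'Dval_loss_T1': T1Loss_Dics,
            'Val_identity': Iden_term, 'Val_Cycle_term': Cycle_term, 'Val_Adver_term': Adv_term}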
