Hello!
I am trying to implement a cycleGAN but I have a problem with the manual optimization. I am not sure what I am doing wrong. I already read many post and GAN tutorial en Pytorchlightning but still I encounter with the same error all the time.
Is any way to approach this? I try toggle and untoggle, and do not work. Also I tried with no_grad:… also no working. Finally I read somewhere that it is a posible error of the version but I already try 2.0.0 and 2.0.1
Here is the error
Here is my code
class CycleGAN(pl.LightningModule):
'''
CycleGAN Class:
@Based on the paper:
- [1] Jun-Yan Zhu*, Taesung Park*, Phillip Isola, and Alexei A. Efros.
"Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks",
in IEEE International Conference on Computer Vision (ICCV), 2017
- [2] Goodfellow, I., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., ... & Bengio, Y. (2020).
Generative adversarial networks. Communications of the ACM, 63(11), 139-144.
- [3] https://arxiv.org/abs/1703.10593
- [4] Dar, S. U., Yurt, M., Karacan, L., Erdem, A., Erdem, E., & Cukur, T. (2019).
Image synthesis in multi-contrast MRI with conditional generative adversarial networks.
IEEE transactions on medical imaging, 38(10), 2375-2388.
- [5] https://lightning.ai/docs/pytorch/stable/notebooks/lightning_examples/basic-gan.html
- [6] https://www.assemblyai.com/blog/pytorch-lightning-for-dummies/
@Description
@Inputs
@Outputs
'''
def __init__(self,
input,
params,
features=64):
super(CycleGAN,self).__init__()
############# Customize class #############
self.save_hyperparameters(params)
self.automatic_optimization = False
self.target_shape=target_shape
#self.device="cuda"
self.lr = params["lr"]
self.b1 = params["b1"]
self.b2 = params["b2"]
self.lbc_T1 = params["lbc_T1"]
self.lbc_T2 = params["lbc_T2"]
self.btch_size = params["batch_size"]
self.target_shape = params["target_shape"]
self.lbi=params["lbi"]
############# Define components #############
self.G_T1_T2=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
self.D_T1=Discriminator(input,HCh=features,n=3)
self.G_T2_T1=Generator(input,out_f=features,lvl_aut=2,lvl_resnt=9)
self.D_T2=Discriminator(input,HCh=features,n=3)
self.G_T1_T2=self.G_T1_T2.apply(self.weights_init)
self.D_T1=self.D_T1.apply(self.weights_init)
self.G_T2_T1=self.G_T2_T1.apply(self.weights_init)
self.D_T2=self.D_T2.apply(self.weights_init)
############# Define loss #############
self.identity_loss = torch.nn.L1Loss()
self.adv_loss = torch.nn.MSELoss() #adversarial loss function to keep track of how well the GAN is fooling the discriminator and how well the discriminator is catching the GAN
self.cycle_loss = torch.nn.L1Loss()
def forward(self, x):
x=self.G_T1_T2(x)
return x
def training_step(self, batch, batch_idx):
'''
@Description:
- [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Compute loss accorgin original implementation of CycleGAN.
@Inputs
@Outputs
'''
############# Initialization #############
real_T1, real_T2 = batch
Gopt,Dopt_T1,Dopt_T2=self.configure_optimizers()
############# update discriminator #############
#### Discriminator T1
self.toggle_optimizer(Dopt_T1)
f_T1 = self.G_T2_T1(real_T2)
T1Loss_Dics = self.DiscLoss(real_T1,f_T1,disc="T1")
self.log("D_loss_T1", T1Loss_Dics, prog_bar=True)
self.manual_backward(T1Loss_Dics,retain_graph=True)
Dopt_T1.step()
Dopt_T1.zero_grad() # Zero out the gradient before backpropagation
self.untoggle_optimizer(Dopt_T1)
#### Discriminator T2
self.toggle_optimizer(Dopt_T2)
f_T2 = self.G_T1_T2(real_T1)
T2Loss_Dics = self.DiscLoss(real_T2,f_T2,disc="T2")
self.log("D_loss_T2", T2Loss_Dics, prog_bar=True)
self.manual_backward(T2Loss_Dics,retain_graph=True)
Dopt_T2.step()
Dopt_T2.zero_grad() # Zero out the gradient before backpropagation
self.untoggle_optimizer(Dopt_T2)
############# update Generator #############
self.toggle_optimizer(Gopt)
gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term = self.GenLoss(real_T1, real_T2)
self.log("G_loss", gen_loss, prog_bar=True)
self.manual_backward(gen_loss) # Update gradients
Gopt.step() # Update optimizer
Gopt.zero_grad()
self.untoggle_optimizer(Gopt)
return {'G_loss': gen_loss, 'D_loss_T2': T2Loss_Dics, 'D_loss_T1': T1Loss_Dics, 'identity': Iden_term,'Cycle_term': Cycle_term, "Adver_term":Adv_term}
def validation_step(self, batch, batch_idx):
############# Initialization #############
real_T1, real_T2 = batch
Gopt,Dopt_T1,Dopt_T2=self.configure_optimizers()
############# update discriminator #############
#### Discriminator T1
self.toggle_optimizer(Dopt_T1)
f_T1 = self.G_T2_T1(real_T2)
T1Loss_Dics = self.DiscLoss(real_T1,f_T1,disc="T1")
Dopt_T1.zero_grad() # Zero out the gradient before backpropagation
self.manual_backward(T1Loss_Dics,retain_graph=True)
Dopt_T1.step()
self.untoggle_optimizer(Dopt_T1)
#### Discriminator T2
self.toggle_optimizer(Dopt_T2)
f_T2 = self.G_T1_T2(real_T1)
T2Loss_Dics = self.DiscLoss(real_T2,f_T2,disc="T2")
Dopt_T2.zero_grad() # Zero out the gradient before backpropagation
self.manual_backward(T2Loss_Dics,retain_graph=True)
Dopt_T2.step()
self.untoggle_optimizer(Dopt_T2)
############# update Generator #############
self.toggle_optimizer(Gopt)
gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term = self.GenLoss(real_T1, real_T2)
Gopt.zero_grad()
self.manual_backward(gen_loss) # Update gradients
Gopt.step() # Update optimizer
self.untoggle_optimizer(Gopt)
########### Loggers ###########
self.log("Dval_loss_T1", T1Loss_Dics, prog_bar=True,on_step=True, on_epoch=True, logger=True)
self.log("Dval_loss_T2", T2Loss_Dics, prog_bar=True,on_step=True, on_epoch=True, logger=True)
self.log("Gval_loss", gen_loss, prog_bar=True,on_step=True, on_epoch=True, logger=True)
return {'Gval_loss': gen_loss, 'Dval_loss_T2': T2Loss_Dics, 'Dval_loss_T1': T1Loss_Dics, 'Val_identity': Iden_term,'Val_Cycle_term': Cycle_term, "Val_Adver_term":Adv_term}
def configure_optimizers(self):
'''
@Description:
- [1] https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
Inicialize the optimizers. As it describe in original CycleGAN implementation [1]
the optimizers are ADAMS.
@Inputs
@Outputs
'''
lr = self.lr
b1 = self.b1
b2 = self.b2
#Gopt_T1_T2 = torch.optim.Adam(self.G_T1_T2.parameters(), lr=lr, betas=(b1, b2))
Dopt_T1= torch.optim.Adam(self.D_T1.parameters(), lr=lr, betas=(b1, b2))
#Gopt_T2_T1 = torch.optim.Adam(self.G_T2_T1.parameters(), lr=lr, betas=(b1, b2))
Dopt_T2 = torch.optim.Adam(self.D_T2.parameters(), lr=lr, betas=(b1, b2))
Gopt= torch.optim.Adam(list(self.G_T1_T2.parameters()) + list(self.G_T2_T1.parameters()), lr=lr, betas=(0.5, 0.999))
return Gopt,Dopt_T1,Dopt_T2
def DiscLoss(self,real,fake,disc="T1"):
'''
@Description
This function computes the discriminator loss using the adversarial loss funtion
MSE. Taking the target label and the discriminator predictions returns the adversarial loss.
With adverarial loss from real and from fake image we compute the discriminator loss such as:
discriminator loss= (adv_fake+adv_real)/2
@Inputs
real. Tensor, real image.
fake. Tensor, fake image.
@Outputs
Discriminator loss
'''
if disc == "T1":
disc_fake_hat = self.D_T1(fake.detach())
disc_real_hat = self.D_T1(real)
else:
disc_fake_hat = self.D_T2(fake.detach())
disc_real_hat = self.D_T2(real)
fake_loss = self.adv_loss(disc_fake_hat, torch.zeros_like(disc_fake_hat))
real_loss = self.adv_loss(disc_real_hat, torch.ones_like(disc_real_hat))
r=(fake_loss + real_loss) / 2
return r
def GenLoss(self, real_T1, real_T2):
'''
@Description
@Inputs
@Outputs
'''
#compute fakes
f_T1 = self.G_T2_T1(real_T2)
f_T2 = self.G_T1_T2(real_T2)
#Compute Discriminators output
dic_f_T1_hat = self.D_T1(f_T1)
dic_f_T2_hat = self.D_T2(f_T2)
# Compute adversarial loss AdvLoss_T2_T1 + AdvLoss_T1_T2
Adv_term=self.adv_loss(dic_f_T1_hat, torch.ones_like(dic_f_T1_hat)) + self.adv_loss(dic_f_T2_hat, torch.ones_like(dic_f_T2_hat))
# Compute Cycles
C_T1 = self.G_T2_T1(f_T2)
C_T2 = self.G_T1_T2(f_T1)
# Compute Cycle consistancy.
Cycle_term=self.lbc_T1*self.cycle_loss(C_T1,real_T1)+self.lbc_T2*self.cycle_loss(C_T2,real_T2)
#Compute Identities
identity_T1 = self.G_T2_T1(real_T1)
identity_T2 = self.G_T1_T2(real_T2)
# Compute Identity term
Iden_term = self.identity_loss (identity_T1, real_T1) + self.identity_loss (identity_T2, real_T2)
# Compute Total loss
gen_loss = self.lbi * Iden_term + Cycle_term + Adv_term
return gen_loss, f_T1, f_T2,Iden_term,Cycle_term,Adv_term
def weights_init(self,m):
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
if isinstance(m, nn.BatchNorm2d):
torch.nn.init.normal_(m.weight, 0.0, 0.02)
torch.nn.init.constant_(m.bias, 0)