Dear PyTorch Lightning community,
I am using PyTorch Lightning to train a GAN, so each training_step involves one generator_step and one discriminator_step (similar to lightning-bolts/basic_gan_module.py at f48357be353b7acdd882379ac3308fbec95dc40d · Lightning-AI/lightning-bolts · GitHub), as sketched below.
I need to run the discriminator_step twice, which means calling backward twice for each batch. What's the best way to do this in PyTorch Lightning?
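For context, the alternating structure looks roughly like this (only a sketch, following the linked bolts module and Lightning's pre-2.0 optimizer_idx signature; generator_step and discriminator_step are assumed to compute and return the respective losses):

# Sketch of the current setup, in the style of the linked bolts module
# (automatic optimization; Lightning's pre-2.0 optimizer_idx API).
# generator_step / discriminator_step are assumed to return the losses.
def training_step(self, batch, batch_idx, optimizer_idx):
    x, _ = batch
    if optimizer_idx == 0:
        return self.generator_step(x)       # one generator update
    if optimizer_idx == 1:
        return self.discriminator_step(x)   # one discriminator update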
Thank you!
goku
October 16, 2020, 7:45pm
2
you mean for a single batch you want to do:
# Generator
gen_loss.backward()
gen_opt.step()
gen_opt.zero_grad()
# Discriminator (two updates)
disc_loss1.backward()
disc_opt.step()
disc_loss2.backward()
disc_opt.step()
disc_opt.zero_grad()
or something else?
teddy
October 16, 2020, 10:20pm
4
See manual optimization. This should allow you to do exactly what you want!
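A minimal sketch of what that could look like, assuming the manual-optimization API from recent Lightning releases (self.automatic_optimization = False, self.optimizers(), self.manual_backward); the toy models, losses, and hyperparameters here are just placeholders, not part of the original question:

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl

class GAN(pl.LightningModule):
    def __init__(self, latent_dim=32, data_dim=64):
        super().__init__()
        self.latent_dim = latent_dim
        # Toy generator/discriminator just to keep the sketch self-contained
        self.generator = nn.Sequential(nn.Linear(latent_dim, data_dim), nn.Tanh())
        self.discriminator = nn.Linear(data_dim, 1)
        self.automatic_optimization = False  # take control of backward/step

    def training_step(self, batch, batch_idx):
        real = batch[0]
        ones = torch.ones(real.size(0), 1, device=self.device)
        zeros = torch.zeros(real.size(0), 1, device=self.device)
        gen_opt, disc_opt = self.optimizers()

        # One generator update
        z = torch.randn(real.size(0), self.latent_dim, device=self.device)
        gen_loss = F.binary_cross_entropy_with_logits(
            self.discriminator(self.generator(z)), ones)
        gen_opt.zero_grad()
        self.manual_backward(gen_loss)
        gen_opt.step()

        # Two discriminator updates: recompute the loss each time,
        # so every backward runs on a graph built from the current weights
        for _ in range(2):
            z = torch.randn(real.size(0), self.latent_dim, device=self.device)
            fake = self.generator(z).detach()
            disc_loss = (F.binary_cross_entropy_with_logits(self.discriminator(real), ones)
                         + F.binary_cross_entropy_with_logits(self.discriminator(fake), zeros))
            disc_opt.zero_grad()
            self.manual_backward(disc_loss)
            disc_opt.step()

    def configure_optimizers(self):
        return (torch.optim.Adam(self.generator.parameters(), lr=2e-4),
                torch.optim.Adam(self.discriminator.parameters(), lr=2e-4))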
goku
October 17, 2020, 1:59pm
5
I think you might get an error here if you use the old weights to calculate disc_loss2: the first disc_opt.step() updates the weights in place, so disc_loss2.backward() would be calculating gradients against the new weights even though its graph was built with the old ones.
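A quick standalone illustration of the failure mode (a hypothetical two-layer discriminator with an MSE loss, chosen only so that backward through the first layer needs the second layer's saved weights): building both losses up front and interleaving step() and backward() trips autograd's in-place check, while rebuilding the loss after each step works.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Hypothetical stand-in for the discriminator: two layers, so backward
# through layer 1 needs layer 2's weights, which autograd saved at forward time.
disc = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 1))
opt = torch.optim.SGD(disc.parameters(), lr=0.1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

# Both losses are built from the *old* weights
disc_loss1 = F.mse_loss(disc(x), y)
disc_loss2 = F.mse_loss(disc(x), y)

disc_loss1.backward()
opt.step()       # updates the weights in place
opt.zero_grad()
try:
    disc_loss2.backward()  # its graph still holds the pre-step weights
except RuntimeError as err:
    print(err)   # "... modified by an inplace operation ..."

# Safe pattern: rebuild the loss after every optimizer step
for _ in range(2):
    loss = F.mse_loss(disc(x), y)
    opt.zero_grad()
    loss.backward()
    opt.step()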
@goku thank you, it was indeed a wrong example, but @teddy's solution works! Thank you all.