I want to train a WGAN-GP model, so I wrote the following code, which is shared by `training_step`, `validation_step`, and `test_step` (the `stage` argument tells them apart):
```python
generator_opt, critic_opt = self.optimizers()

# --- Generator step ---
random_noise = torch.randn((fg_images.shape[0],) + (self.latent,)).to(fg_images)
self.toggle_optimizer(generator_opt)
fake = self.generator(random_noise)
generator_loss = self.critic(fake).mean()
if stage == 'train':
    self.manual_backward(generator_loss)
    generator_opt.step()
    generator_opt.zero_grad()
self.untoggle_optimizer(generator_opt)

# --- Critic step ---
self.toggle_optimizer(critic_opt)
critic_real = self.critic(fg_images).mean()
critic_loss = -self.critic(fake.detach()).mean() + critic_real
if stage == 'train':
    gradient_penalty = 10 * self.gradient_penalty(fg_images.detach(), fake.detach())
    self.manual_backward(critic_loss + gradient_penalty)
    critic_opt.step()
    critic_opt.zero_grad()
self.untoggle_optimizer(critic_opt)
```
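For context, `self.gradient_penalty` is the usual interpolation-based WGAN-GP penalty (the `10 *` factor above is the penalty weight). Roughly, it looks like this simplified sketch; the exact implementation is not shown here and the sketch assumes 4-D image batches:

```python
def gradient_penalty(self, real, fake):
    # Random interpolation between real and fake samples (assumes (N, C, H, W) tensors)
    eps = torch.rand(real.shape[0], 1, 1, 1, device=real.device)
    interpolated = (eps * real + (1 - eps) * fake).requires_grad_(True)

    scores = self.critic(interpolated)
    grads = torch.autograd.grad(
        outputs=scores,
        inputs=interpolated,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize deviation of the gradient norm from 1
    grads = grads.view(grads.shape[0], -1)
    return ((grads.norm(2, dim=1) - 1) ** 2).mean()
```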
My understanding is that `toggle_optimizer` disables gradients (sets `requires_grad = False`) for any parameters that are not associated with the optimizer being toggled, and `untoggle_optimizer` restores them. Out of curiosity, I wrote the following code, which is similar to the above, except that I used my own toggle method:
```python
generator_opt, critic_opt = self.optimizers()  # as in the first version

# --- Generator step (critic frozen manually) ---
random_noise = torch.randn((fg_images.shape[0],) + (self.latent,)).to(fg_images)
fake = self.generator(random_noise)
freeze(self.critic)
generator_loss = self.critic(fake).mean()
unfreeze(self.critic)
if stage == 'train':
    self.manual_backward(generator_loss)
    generator_opt.step()
    generator_opt.zero_grad()

# --- Critic step ---
critic_real = self.critic(fg_images).mean()
critic_loss = -self.critic(fake.detach()).mean() + critic_real
if stage == 'train':
    gradient_penalty = 10 * self.gradient_penalty(fg_images.detach(), fake.detach())
    self.manual_backward(critic_loss + gradient_penalty)
    critic_opt.step()
    critic_opt.zero_grad()
```
where `freeze` and `unfreeze` are defined as:
```python
def freeze(model):
    for p in model.parameters():
        p.requires_grad_(False)
    model.eval()


def unfreeze(model):
    for p in model.parameters():
        p.requires_grad_(True)
    model.train()
```
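For comparison, my mental model of `toggle_optimizer` / `untoggle_optimizer` is roughly the following (a hypothetical sketch of what I assume happens, not Lightning's actual code):

```python
# Hypothetical sketch: freeze every parameter that does not belong to the toggled optimizer.
def my_toggle_optimizer(module, optimizer):
    toggled_params = {p for group in optimizer.param_groups for p in group['params']}
    saved_flags = {}
    for p in module.parameters():
        if p not in toggled_params:
            saved_flags[p] = p.requires_grad
            p.requires_grad_(False)  # gradients disabled for the "other" networks
    return saved_flags


def my_untoggle_optimizer(saved_flags):
    for p, flag in saved_flags.items():
        p.requires_grad_(flag)  # restore the original requires_grad flags
```

If that mental model is right, the main difference from my `freeze`/`unfreeze` helpers is that `toggle_optimizer` never switches the frozen network between `eval()` and `train()` mode.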
The first approach doesn't work (the generated images are of poor quality), while the second approach (`freeze`/`unfreeze`) does. Why?
Another question: can I toggle multiple optimizers at the same time? For instance, I have another network, let's call it `helper`, which is supposed to work together with the generator, so its loss is essentially part of the generator loss. As far as I can tell, `toggle_optimizer` only lets me toggle one optimizer at a time.
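To make the question concrete, this is roughly what I would like to be able to write (hypothetical code; `helper_opt`, `self.helper`, and `self.helper_loss` are made-up names, and I don't know whether a second toggle is even allowed):

```python
# Hypothetical: joint generator + helper update with both optimizers toggled.
self.toggle_optimizer(generator_opt)
self.toggle_optimizer(helper_opt)  # is a second toggle legal here?

fake = self.generator(random_noise)
helper_out = self.helper(fake)  # made-up helper network
total_loss = self.critic(fake).mean() + self.helper_loss(helper_out)  # hypothetical combined loss

self.manual_backward(total_loss)
generator_opt.step()
helper_opt.step()
generator_opt.zero_grad()
helper_opt.zero_grad()

self.untoggle_optimizer(helper_opt)
self.untoggle_optimizer(generator_opt)
```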