When I was implementing a GAN with Lightning, I train the discriminator first and wanted to reuse the generated fake images for training the generator. So I used `.detach()` in the first training step like this:
```python
class DCGAN(pl.LightningModule):
    ...
    def training_step(self, batch, batch_idx, optimizer_idx):
        imgs, _ = batch

        if optimizer_idx == 0:
            # Train discriminator
            z = torch.randn(imgs.shape[0], self.hparams.latent_dim, device=self.device)
            self.generated_imgs = self(z)  # generate fake images
            disc_real = self.disc(imgs)
            disc_fake = self.disc(self.generated_imgs.detach())  # this should have the effect of retaining the computational graph
            real_loss = self.criterion(disc_real, torch.ones_like(disc_real))
            fake_loss = self.criterion(disc_fake, torch.zeros_like(disc_fake))
            d_loss = (real_loss + fake_loss) / 2
            self.log('loss/disc', d_loss, on_epoch=True, prog_bar=True)
            return d_loss

        if optimizer_idx >= 1:
            # Train generator
            disc_fake = self.disc(self.generated_imgs)
            g_loss = self.criterion(disc_fake, torch.ones_like(disc_fake))
            self.log('loss/gen', g_loss, on_epoch=True, prog_bar=True)
            return g_loss
```
But I get:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```
I have found the problem: the graph built by `generated_imgs = self(z)` was also released in the first training step even with `.detach()`. (According to @goku, `.detach()` is not the point here; rather, the gradients of the generator will not be calculated when `optimizer_idx == 0`.)
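To see what that explanation means in isolation, here is a minimal plain-PyTorch sketch (with a toy `nn.Linear` standing in for the generator, not my actual model): when a module's parameters have `requires_grad` turned off, as Lightning's optimizer toggling does for the generator during the discriminator step, the forward pass records no graph at all, and calling `backward()` on any loss built from that output raises the exact same error:

```python
import torch
import torch.nn as nn

gen = nn.Linear(4, 4)   # toy stand-in for the generator
z = torch.randn(2, 4)

# Simulate Lightning toggling the generator off during the discriminator step
for p in gen.parameters():
    p.requires_grad = False

fake = gen(z)           # no graph is recorded: nothing here requires grad
print(fake.grad_fn)     # None

loss = fake.sum()
try:
    loss.backward()     # same RuntimeError as in training_step
except RuntimeError as e:
    print(e)
```

So with automatic optimization, the `self.generated_imgs` saved in the `optimizer_idx == 0` branch carries no `grad_fn`, and the generator step has nothing to backpropagate through.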
However, if I instead use plain PyTorch or Lightning's manual optimization (https://pytorch-lightning.readthedocs.io/en/latest/optimizers.html#manual-optimization), everything works out fine.
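For reference, this is roughly the plain-PyTorch version that works for me (a sketch with made-up toy modules, not my real DCGAN): the fake batch is generated once, detached for the discriminator update, and then reused with its graph intact for the generator update.

```python
import torch
import torch.nn as nn

# Toy stand-ins for the generator and discriminator
gen = nn.Linear(8, 4)
disc = nn.Linear(4, 1)
opt_g = torch.optim.Adam(gen.parameters())
opt_d = torch.optim.Adam(disc.parameters())
criterion = nn.BCEWithLogitsLoss()

real = torch.randn(16, 4)
z = torch.randn(16, 8)
fake = gen(z)  # generated once; its graph stays alive

# Discriminator step: detach so gradients don't flow into the generator
opt_d.zero_grad()
disc_real = disc(real)
disc_fake = disc(fake.detach())
d_loss = (criterion(disc_real, torch.ones_like(disc_real))
          + criterion(disc_fake, torch.zeros_like(disc_fake))) / 2
d_loss.backward()
opt_d.step()

# Generator step: reuse the same fake batch, graph still intact
opt_g.zero_grad()
out = disc(fake)
g_loss = criterion(out, torch.ones_like(out))
g_loss.backward()  # works: gradients reach gen through the retained graph
opt_g.step()
```

Here nothing ever disables `requires_grad` on the generator, so `fake` keeps its `grad_fn` and the second `backward()` succeeds.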
So what happened in the first case?