Hi there,
I’m currently trying to train a GAN using PyTorch Lightning, and I want to implement multiple discriminator updates per generator update, but I’m unsure of the best practice. Some implementations do several discriminator updates within the same batch; other implementations simply don’t update the generator over the course of a handful of batches.
Below is my training step for the GAN:
def training_step(self, batch, batch_idx):
    optG, optD = self.optimizers()
    # Data and real/fake labels
    real_data = batch
    batch_size = real_data.size(0)
    real_labels = torch.full((batch_size,), 1.0, dtype=torch.float).type_as(real_data)
    fake_labels = torch.full((batch_size,), 0.0, dtype=torch.float).type_as(real_data)
    # Generate fake data (for my use case, the batch size is an argument of the generator's forward)
    fake_data = self.model.generator(batch_size).type_as(real_data)
    # Train the generator
    self.toggle_optimizer(optG)
    optG.zero_grad()
    outD_fake = self.model.discriminator(fake_data)
    errG = self.criterion(outD_fake, real_labels)
    self.manual_backward(errG)
    optG.step()
    self.untoggle_optimizer(optG)
    # Train the discriminator
    self.toggle_optimizer(optD)
    optD.zero_grad()
    outD_real = self.model.discriminator(real_data)
    outD_fake = self.model.discriminator(fake_data.detach())
    errD_real = self.criterion(outD_real, real_labels)  # discriminator loss on real data
    errD_fake = self.criterion(outD_fake, fake_labels)  # discriminator loss on fake data
    errD = (errD_real + errD_fake) / 2
    self.manual_backward(errD)
    self.log("train_d_loss_step", errD, prog_bar=True)
    optD.step()
    self.untoggle_optimizer(optD)
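For the same-batch variant, this is what I imagine the discriminator block above turning into: an inner loop that takes several discriminator steps per generator step. This is an untested sketch; n_critic is just a placeholder hyperparameter, and I resample fake data on each step, which I've seen some implementations do:

    # Untested sketch: replace the discriminator block above with an inner loop
    n_critic = 4  # placeholder: discriminator steps per generator step
    self.toggle_optimizer(optD)
    for _ in range(n_critic):
        optD.zero_grad()
        # Optionally resample fake data for each discriminator step
        fake_data = self.model.generator(batch_size).type_as(real_data)
        outD_real = self.model.discriminator(real_data)
        outD_fake = self.model.discriminator(fake_data.detach())
        errD = (self.criterion(outD_real, real_labels)
                + self.criterion(outD_fake, fake_labels)) / 2
        self.manual_backward(errD)
        optD.step()
    self.untoggle_optimizer(optD)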
One version of the implementation does something like the following when updating the generator, in order to get more discriminator updates than generator updates; here, however, the extra discriminator updates are spread over multiple batches:
if (batch_idx + 1) % 4 == 0:
    optG.step()
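Fleshed out inside training_step, I'd expect that cross-batch variant to look roughly like this (again an untested sketch): the discriminator block runs on every batch as before, while the whole generator block is skipped except on every 4th batch:

    # Untested sketch: run the generator block only on every 4th batch,
    # while the discriminator block still runs on every batch
    if (batch_idx + 1) % 4 == 0:
        self.toggle_optimizer(optG)
        optG.zero_grad()
        outD_fake = self.model.discriminator(fake_data)
        errG = self.criterion(outD_fake, real_labels)
        self.manual_backward(errG)
        optG.step()
        self.untoggle_optimizer(optG)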
Any advice on which implementation to use in this framework would be greatly appreciated, as would pointers on how to do multiple discriminator updates on the same batch!