Multiple discriminator network updates during GAN training

Hi there,
I’m currently trying to train a GAN using PyTorch Lightning and I want to implement multiple discriminator updates per generator update, but I am unsure of the best practice. Some implementations do multiple discriminator updates within the same batch, while others simply skip the generator update over the course of a handful of batches.

Below is my training step for the GAN:

 def training_step(self, batch, batch_idx):

        optG, optD = self.optimizers()

        # data and real/fake labels
        real_data = batch
        
        real_labels = torch.full((real_data.size(0),), 1.0, dtype=torch.float).type_as(
            real_data
        )
        fake_labels = torch.full((real_data.size(0),), 0.0, dtype=torch.float).type_as(
            real_data
        )

        batch_size = real_data.size(0)
        # Generate fake-data
        fake_data = self.model.generator(batch_size).type_as(real_data)  # For my use case, batch size is part of the forward pass
        
        # Training the generator
        self.toggle_optimizer(optG)
        optG.zero_grad()
        outD_fake = self.model.discriminator(fake_data)
        errG = self.criterion(outD_fake, real_labels)
        self.manual_backward(errG, retain_graph=True)
        optG.step()
        self.untoggle_optimizer(optG)

        # Training the discriminator
        self.toggle_optimizer(optD)
        optD.zero_grad()
        outD_real = self.model.discriminator(real_data)
        outD_fake = self.model.discriminator(fake_data.detach())
        errD_real = self.criterion(outD_real, real_labels)  # Discriminator real loss
        errD_fake = self.criterion(outD_fake, fake_labels)  # Discriminator fake loss
        errD = (errD_real + errD_fake) / 2
        self.manual_backward(errD)
        self.log("train_d_loss_step", errD, prog_bar=True)
        optD.step()
        self.untoggle_optimizer(optD)
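For the within-batch variant, my understanding is that it boils down to an inner loop over the discriminator step before the single generator step. Here is a minimal, stand-alone PyTorch sketch of what I mean (the toy linear modules, the `n_critic = 4` ratio, and the batch of random data are placeholders, not my actual model):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
G = nn.Linear(2, 2)                      # toy generator: noise -> "data"
D = nn.Linear(2, 1)                      # toy discriminator: data -> logit
optG = torch.optim.Adam(G.parameters(), lr=1e-3)
optD = torch.optim.Adam(D.parameters(), lr=1e-3)
criterion = nn.BCEWithLogitsLoss()
n_critic = 4                             # D updates per G update (assumed ratio)

real_data = torch.randn(8, 2)            # stand-in for one batch
real_labels = torch.ones(8, 1)
fake_labels = torch.zeros(8, 1)

# Several discriminator updates on the same batch, fresh fakes each time
for _ in range(n_critic):
    fake_data = G(torch.randn(8, 2))
    optD.zero_grad()
    errD = (criterion(D(real_data), real_labels)
            + criterion(D(fake_data.detach()), fake_labels)) / 2
    errD.backward()
    optD.step()

# One generator update after the inner loop
optG.zero_grad()
fake_data = G(torch.randn(8, 2))
errG = criterion(D(fake_data), real_labels)
errG.backward()
optG.step()
```

Regenerating `fake_data` inside the inner loop (rather than reusing one set of fakes) also means `retain_graph=True` is not needed, since each discriminator step builds its own graph.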

One version of the implementation only steps the generator every few batches when updating it, so that the discriminator receives more updates than the generator; this spreads the schedule over multiple batches, however:

if (batch_idx + 1) % 4 == 0:
    optG.step()
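To make the difference between the two schedules concrete, here is a library-free sketch of the bookkeeping (the 4:1 ratio and 8-batch window are just illustrative numbers I picked): both strategies give four discriminator updates per generator update, but strategy A spreads the generator steps across batches while strategy B packs all updates into each batch.

```python
N_CRITIC = 4      # discriminator updates per generator update (assumed)
NUM_BATCHES = 8

# Strategy A: one D update per batch, G update only every N_CRITIC-th batch
a_d_steps, a_g_steps = 0, 0
for batch_idx in range(NUM_BATCHES):
    a_d_steps += 1                        # optD.step() runs every batch
    if (batch_idx + 1) % N_CRITIC == 0:
        a_g_steps += 1                    # optG.step() runs every 4th batch

# Strategy B: N_CRITIC D updates plus one G update, all on the same batch
b_d_steps, b_g_steps = 0, 0
for batch_idx in range(NUM_BATCHES):
    for _ in range(N_CRITIC):
        b_d_steps += 1                    # inner loop: optD.step() 4 times
    b_g_steps += 1                        # optG.step() once per batch

print(a_d_steps, a_g_steps)  # 8 2
print(b_d_steps, b_g_steps)  # 32 8
```

Both end up at a 4:1 ratio, so the choice is about how fresh the discriminator is at each generator step, not about the overall count.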

Any advice on which implementation to use in this framework would be greatly appreciated, as well as on how to do multiple discriminator updates for the same batch!