RuntimeError is thrown when calling `manual_backward` from methods other than `training_step`

When LightningModule.manual_backward is called from a method other than LightningModule.training_step, a RuntimeError is raised. How can I use manual_backward in my custom methods?

OS: Windows 10
Python: 3.10 (Anaconda)
PyTorch: 2.0 (CUDA 11.7)
PyTorch Lightning: 2.0.1

Hello. I’m trying to implement pix2pix. The code below is the training part (see this gist for details).

class Pix2Pix(LightningModule):
    def __init__(self, ...) -> None:
        super().__init__()
        self.automatic_optimization = False

    def train_discriminator(self, cond_images: Tensor, real_images: Tensor, fake_images: Tensor) -> None:
        self.toggle_optimizer(optimizer := self.optimizers()[0])

        preds_real = self.discriminator(cond_images, real_images)
        preds_fake = self.discriminator(cond_images, fake_images.detach())
        loss_real = self.adversarial_loss(preds_real, as_real=True)
        loss_fake = self.adversarial_loss(preds_fake, as_real=False)
        loss = (loss_real + loss_fake) / 2

        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("d_loss", loss, prog_bar=True)
        self.log("d_loss_real", loss_real)
        self.log("d_loss_fake", loss_fake)

    def train_generator(self, cond_images: Tensor, fake_images: Tensor, real_images: Tensor) -> None:
        self.toggle_optimizer(optimizer := self.optimizers()[1])

        preds = self.discriminator(cond_images, fake_images)
        loss_adv = self.adversarial_loss(preds, as_real=True)
        loss_rcn = self.reconstruction_loss(fake_images, real_images)
        loss = loss_adv + self.lambda_rcn * loss_rcn

        optimizer.zero_grad()
        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)

        self.log("g_loss", loss, prog_bar=True)
        self.log("g_loss_adv", loss_adv)
        self.log("g_loss_rcn", loss_rcn)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> None:
        if self.input_first:
            cond_images, real_images = batch
        else:
            real_images, cond_images = batch
        fake_images = self.forward(cond_images)
        self.train_discriminator(cond_images, real_images, fake_images)
        self.train_generator(cond_images, fake_images, real_images)
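For the indices used above (`self.optimizers()[0]` for the discriminator, `[1]` for the generator) to line up, `configure_optimizers` would need to return the optimizers in that order. A sketch, since the gist’s version isn’t shown here — the Adam hyperparameters are placeholders, not the real values:

```python
    def configure_optimizers(self):
        # Order matters: index 0 must be the discriminator's optimizer and
        # index 1 the generator's, to match self.optimizers()[0] / [1] above.
        opt_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))
        opt_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
        return [opt_d, opt_g]
```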

However, the following error occurred when I tried to train the model using Trainer:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Since it was difficult to determine exactly where the problem occurred from the error message alone, I ran the following two experiments.

  1. Train the networks manually in traditional PyTorch fashion (i.e., with an explicit for-loop).
    It worked, which implies my implementations of self.generator and self.discriminator have no problem.

  2. Merge the three methods into a single training_step that executes manual_backward for both self.discriminator and self.generator, respectively.
    It also worked, which implies that at least the PyTorch and LightningModule APIs are functioning properly.
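For reference, experiment 1’s plain-PyTorch loop followed the usual GAN pattern: detach the fake images for the discriminator step, then run a fresh discriminator forward for the generator step. A toy sketch — the linear networks, shapes, and BCE losses here are placeholders, not the actual pix2pix models:

```python
import torch
from torch import nn

# Toy stand-ins for the real generator/discriminator; shapes are arbitrary.
generator = nn.Linear(8, 8)
discriminator = nn.Linear(16, 1)  # takes cond and image concatenated
opt_d = torch.optim.SGD(discriminator.parameters(), lr=0.1)
opt_g = torch.optim.SGD(generator.parameters(), lr=0.1)
bce = nn.BCEWithLogitsLoss()

cond = torch.randn(4, 8)
real = torch.randn(4, 8)

for _ in range(2):
    fake = generator(cond)

    # Discriminator step: detach fake so D's backward stops at G's output.
    preds_real = discriminator(torch.cat([cond, real], dim=1))
    preds_fake = discriminator(torch.cat([cond, fake.detach()], dim=1))
    d_loss = (bce(preds_real, torch.ones_like(preds_real))
              + bce(preds_fake, torch.zeros_like(preds_fake))) / 2
    opt_d.zero_grad()
    d_loss.backward()
    opt_d.step()

    # Generator step: fresh forward through D, so this backward traverses
    # D's new graph and G's graph each exactly once -- no double backward.
    preds = discriminator(torch.cat([cond, fake], dim=1))
    g_loss = bce(preds, torch.ones_like(preds))
    opt_g.zero_grad()
    g_loss.backward()
    opt_g.step()
```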

Therefore, I conclude that calling self.manual_backward from both custom methods is the problem. What should I do to solve this? Or have I made a mistake somewhere?

BTW, as a Ph.D. student, thank you for creating such a wonderful library.

This is because in train_generator you are using the discriminator: you backward through both the generator and the discriminator, but only step the generator’s optimizer, I guess. When you then get to the train_discriminator call, you try to backward a second time through the discriminator.

I see 3 possible solutions here:

  1. Set retain_graph=True in the first backward call, as the error message suggests.
  2. Flip the order of optimization: call self.train_generator first and then self.train_discriminator.
  3. Freeze the discriminator (set its weights to requires_grad=False) during train_generator so you don’t backward through it. Don’t forget to unfreeze it afterwards.

I think at least one of these ideas should work.
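Solution 3 can be sketched with a small helper (`set_requires_grad` is a hypothetical name, not a Lightning API). Inside train_generator you would freeze self.discriminator before the backward and unfreeze it afterwards:

```python
import torch
from torch import nn

def set_requires_grad(module: nn.Module, flag: bool) -> None:
    """Freeze (flag=False) or unfreeze (flag=True) all parameters of a module."""
    for p in module.parameters():
        p.requires_grad_(flag)

# Toy discriminator standing in for self.discriminator.
discriminator = nn.Linear(4, 1)

set_requires_grad(discriminator, False)   # freeze before the generator step
frozen = all(not p.requires_grad for p in discriminator.parameters())

# ... generator forward, manual_backward, optimizer.step() would go here ...

set_requires_grad(discriminator, True)    # don't forget to unfreeze afterwards
unfrozen = all(p.requires_grad for p in discriminator.parameters())
```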