TL;DR

When `LightningModule.manual_backward` is called from a method other than `LightningModule.training_step`, a `RuntimeError` is raised. How can I use `manual_backward` in my custom methods?
Environment
OS: Windows 10
Python: 3.10 (Anaconda)
PyTorch: 2.0 (CUDA 11.7)
PyTorch Lightning: 2.0.1
Content
Hello. I'm trying to implement pix2pix. The code below is the training part (see this gist for details).
```python
class Pix2Pix(LightningModule):
    def __init__(self, ...) -> None:
        super().__init__()
        self.automatic_optimization = False
        ...

    ...

    def train_discriminator(self, cond_images: Tensor, real_images: Tensor, fake_images: Tensor) -> None:
        self.toggle_optimizer(optimizer := self.optimizers()[0])
        optimizer.zero_grad()
        preds_real = self.discriminator(cond_images, real_images)
        preds_fake = self.discriminator(cond_images, fake_images.detach())
        loss_real = self.adversarial_loss(preds_real, as_real=True)
        loss_fake = self.adversarial_loss(preds_fake, as_real=False)
        loss = (loss_real + loss_fake) / 2
        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)
        self.log("d_loss", loss, prog_bar=True)
        self.log("d_loss_real", loss_real)
        self.log("d_loss_fake", loss_fake)

    def train_generator(self, cond_images: Tensor, fake_images: Tensor, real_images: Tensor) -> None:
        self.toggle_optimizer(optimizer := self.optimizers()[1])
        optimizer.zero_grad()
        preds = self.discriminator(cond_images, fake_images)
        loss_adv = self.adversarial_loss(preds, as_real=True)
        loss_rcn = self.reconstruction_loss(fake_images, real_images)
        loss = loss_adv + self.lambda_rcn * loss_rcn
        self.manual_backward(loss)
        optimizer.step()
        self.untoggle_optimizer(optimizer)
        self.log("g_loss", loss, prog_bar=True)
        self.log("g_loss_adv", loss_adv)
        self.log("g_loss_rcn", loss_rcn)

    def training_step(self, batch: tuple[Tensor, Tensor]) -> None:
        if self.input_first:
            cond_images, real_images = batch
        else:
            real_images, cond_images = batch
        fake_images = self.forward(cond_images)
        self.train_discriminator(cond_images, fake_images, real_images)
        self.train_generator(cond_images, fake_images, real_images)
```
However, the following error occurred when I tried to train the model using `Trainer`:

```
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
```
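For context, this error is not specific to Lightning: plain PyTorch raises the same `RuntimeError` whenever `backward()` runs twice over the same graph, because the first call frees the saved intermediate tensors. A minimal standalone repro (unrelated to the pix2pix code above):

```python
import torch

# Build a tiny graph and backward through it once.
x = torch.ones(3, requires_grad=True)
y = (x * 2).sum()
y.backward()  # frees the graph's saved intermediate values

# A second backward through the same (now freed) graph fails
# unless the first call used retain_graph=True.
second_backward_failed = False
try:
    y.backward()
except RuntimeError:
    second_backward_failed = True

print("second backward failed:", second_backward_failed)
```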
Since the error message alone made it difficult to determine exactly where the problem occurred, I ran the following two experiments.

1. Train the networks manually in traditional PyTorch fashion (i.e., using a for-loop). This worked, which implies that my implementations of `self.generator` and `self.discriminator` are not the problem.
2. Merge the three methods into a single `training_step` that executes `manual_backward` for both `self.discriminator` and `self.generator`. This also worked, which implies that at least the PyTorch and `LightningModule` APIs involved are not the problem.
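For reference, experiment 1 looked roughly like this — a runnable toy sketch with placeholder single-layer networks, a BCE adversarial loss, and a one-batch loop; the real models, losses, and data are different:

```python
import torch
from torch import nn

# Toy stand-ins for the real generator/discriminator (hypothetical shapes).
generator = nn.Linear(4, 4)
discriminator = nn.Linear(8, 1)
opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-3)
opt_g = torch.optim.Adam(generator.parameters(), lr=1e-3)
bce = nn.BCEWithLogitsLoss()

for cond, real in [(torch.randn(2, 4), torch.randn(2, 4))]:
    fake = generator(cond)

    # Discriminator step: detach fake so the generator's graph is untouched.
    opt_d.zero_grad()
    preds_real = discriminator(torch.cat([cond, real], dim=1))
    preds_fake = discriminator(torch.cat([cond, fake.detach()], dim=1))
    d_loss = (bce(preds_real, torch.ones_like(preds_real))
              + bce(preds_fake, torch.zeros_like(preds_fake))) / 2
    d_loss.backward()
    opt_d.step()

    # Generator step: backward through the still-intact generator graph.
    opt_g.zero_grad()
    preds = discriminator(torch.cat([cond, fake], dim=1))
    g_loss = bce(preds, torch.ones_like(preds))
    g_loss.backward()
    opt_g.step()
```

Because the discriminator step only sees `fake.detach()`, the generator's graph is consumed exactly once, in the generator step, and no double-backward occurs.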
Therefore, I conclude that calling `self.manual_backward` from both custom methods is the problem. What should I do to solve this? Or have I made a mistake somewhere?
BTW, as a Ph.D. student, thank you for creating such a wonderful library.