I would like to know the correct way to pass retain_graph=True
to backward() in a PyTorch Lightning model. Currently, I am using:
    def on_backward(self, use_amp, loss, optimizer):
        loss.backward(retain_graph=True)
But when I run training, PyTorch still raises the error saying retain_graph needs to be True for a second backward pass, as if my override were never being called. Thanks!
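For context, here is a minimal sketch in plain PyTorch (outside Lightning) of what retain_graph=True is meant to enable: calling backward() a second time on the same graph fails unless the first call retained it. This is illustrative only and does not use the Lightning hook above.

```python
import torch

# Build a tiny graph: y depends on x.
x = torch.ones(2, requires_grad=True)
y = (x * 2).sum()

# First backward keeps the graph alive for a later pass.
y.backward(retain_graph=True)

# Second backward succeeds only because the graph was retained;
# gradients accumulate into x.grad (2 + 2 per element).
y.backward()
print(x.grad)  # tensor([4., 4.])
```

If the first call were a plain y.backward(), the second call would raise the "Trying to backward through the graph a second time" error, which is the same error I am seeing in my Lightning run.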