Use retain_graph=True in a PyTorch Lightning model

I would like to know the correct way to include retain_graph=True in a pytorch_lightning model. Currently, I am using:

    def on_backward(self, use_amp, loss, optimizer):
        loss.backward(retain_graph=True)

But when I run it, it still complains that retain_graph needs to be True for a second backward pass. Thanks!

Hey @mamunm,

The correct hook to override is just backward.
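
For example, something along these lines (a minimal sketch: the model name, layer sizes, and loss are placeholders, and accepting *args/**kwargs keeps the override compatible across Lightning versions, since some releases pass extra arguments such as the optimizer to this hook):

    import torch
    from torch import nn
    import pytorch_lightning as pl

    class MyModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(32, 1)

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.mse_loss(self.layer(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters())

        def backward(self, loss, *args, **kwargs):
            # Keep the computation graph alive so a second backward
            # pass through the same graph does not fail.
            loss.backward(retain_graph=True)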

Also, we have moved discussions over to GitHub Discussions. You might want to post there instead for a quicker response; the forums will be marked read-only soon.

Thank you