Resolving the backward() with create_graph=True memory leak

I’m using manual_backward(…, create_graph=True) with a second-order optimizer (AdaHessian from torch-optimizer). I get the warning below, and sure enough, I get the memory leak (running out of memory after 2 epochs). How is this problem resolved in pytorch-lightning, and has anyone else run into it? A rough sketch of what I'm trying is included after the warning.

“UserWarning: Using backward() with create_graph=True will create a reference cycle between the parameter and its gradient which can cause a memory leak. We recommend using autograd.grad when creating the graph to avoid this. If you have to use this function, make sure to reset the .grad fields of your parameters to None after use to break the cycle and avoid the leak.”
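For context, here is a minimal sketch of the kind of setup I mean, following the warning's suggestion to reset the .grad fields to None after the optimizer step. The model, data shapes, and learning rate are just placeholders, and I'm assuming the torch_optimizer package's Adahessian class; I'm not sure whether clearing grads manually like this is the intended fix inside a LightningModule.

```python
import torch.nn as nn
import pytorch_lightning as pl
import torch_optimizer  # pip install torch-optimizer; provides Adahessian


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 1))
        self.loss_fn = nn.MSELoss()

    def training_step(self, batch, batch_idx):
        x, y = batch
        opt = self.optimizers()
        opt.zero_grad()

        loss = self.loss_fn(self.model(x), y)
        # create_graph=True so the second-order optimizer can differentiate
        # through the gradients in its step()
        self.manual_backward(loss, create_graph=True)
        opt.step()

        # Attempt to break the parameter <-> grad reference cycle
        # described in the warning by dropping the grads after use
        for p in self.parameters():
            p.grad = None

        return loss

    def configure_optimizers(self):
        return torch_optimizer.Adahessian(self.parameters(), lr=1e-1)
```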