How to make a tensor's requires_grad=True even in validation_step of pytorch_lightning
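```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 1)

    def forward(self, x):
        out = self.net(x)  # shape (N, 1), so grad_outputs is required
        # gradient of out w.r.t. the input x (x must require grad)
        return torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out), create_graph=True, retain_graph=True)
```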
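For reference, here is a self-contained usage sketch of `Net` in ordinary grad-enabled mode. The random batch of 4 inputs is an assumption for illustration. `grad_outputs` is passed because `out` is not a scalar, and the input has to be created with `requires_grad=True`. Because `out` is linear in `x`, each row of the returned gradient should equal the layer's weight row.

```python
import torch
import torch.nn as nn


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(10, 1)

    def forward(self, x):
        out = self.net(x)
        return torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out),
                                   create_graph=True, retain_graph=True)


model = Net()
# The input must require grad, otherwise autograd.grad has nothing to differentiate.
x = torch.randn(4, 10, requires_grad=True)
(dx,) = model(x)  # autograd.grad returns a tuple with one entry per input
```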
In such a model, `with torch.no_grad()` does not work, because the forward pass itself needs autograd to compute the gradient of `out` with respect to `x`. However, `torch.no_grad()` is probably applied in the `validation_step` and `test_step` of pytorch_lightning. Is there any way to make `requires_grad=True` in `validation_step` as well?
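The failure can be reproduced in plain PyTorch. The sketch below wraps the call in `torch.no_grad()` as a stand-in for the evaluation loop (it is not Lightning's actual code). Under no-grad mode `out` is never recorded by autograd, so `torch.autograd.grad` raises a `RuntimeError`, even though `x` itself requires grad.

```python
import torch
import torch.nn as nn

layer = nn.Linear(10, 1)
x = torch.randn(4, 10, requires_grad=True)

# Stand-in for an evaluation loop that wraps validation_step in torch.no_grad().
with torch.no_grad():
    out = layer(x)
    try:
        torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out))
        raised = False
    except RuntimeError:
        # out has no grad_fn because the forward pass was not recorded
        raised = True
```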
I have tried decorating the method with `@torch.enable_grad()` and wrapping the code in `with torch.enable_grad():`, but neither worked.
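One possible reason the attempts fail is that re-enabling grad mode is not enough on its own: the input batch also has to require grad, and tensors coming from a `DataLoader` do not by default. The sketch below uses plain PyTorch, a hypothetical helper `input_gradient`, and a `torch.no_grad()` block as a stand-in for the real Lightning loop. It shows the decorator form working under no-grad mode once the input is re-marked with `requires_grad_()`. If it still fails inside Lightning, note that recent Lightning versions run evaluation under `torch.inference_mode()` by default rather than `torch.no_grad()`, which `torch.enable_grad()` apparently does not undo. Passing `Trainer(inference_mode=False)` switches evaluation back to `torch.no_grad()`.

```python
import torch
import torch.nn as nn

layer = nn.Linear(10, 1)


@torch.enable_grad()
def input_gradient(x):
    # Hypothetical helper: grad mode is on inside this function even if the
    # caller is in no_grad. The input must still become a leaf that requires
    # grad, since batches from a DataLoader do not require grad by default.
    x = x.detach().requires_grad_()
    out = layer(x)
    (dx,) = torch.autograd.grad(out, x, grad_outputs=torch.ones_like(out))
    return dx


with torch.no_grad():  # stand-in for the evaluation loop's wrapper
    outer_grad_enabled = torch.is_grad_enabled()
    dx = input_gradient(torch.randn(3, 10))
```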