Hi, I am converting some PyTorch code to Lightning, and I was wondering whether I need to explicitly add `torch.no_grad()`.
Here’s the PyTorch code:
```python
from typing import List

import torch
from beartype import beartype
from torch.nn import Module


@beartype
def extract_output_shapes(
    modules: List[Module],
    model: Module,
    model_input,
    model_kwargs: dict = dict()
):
    shapes = []
    hooks = []

    def hook_fn(_, input, output):
        # record the output shape; returning None leaves the output unchanged
        shapes.append(output.shape)

    for module in modules:
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

    # should I add this line of code?
    with torch.no_grad():
        model(model_input, **model_kwargs)

    for hook in hooks:
        hook.remove()

    return shapes
```
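For reading shapes, `torch.no_grad()` does not change the result: the output shape is identical either way. What it changes is that autograd stops recording the forward pass, so no graph (and no saved activations) is kept around. A minimal check with a stand-in `nn.Linear` layer:

```python
import torch
from torch import nn

# A tiny stand-in model; any nn.Module behaves the same way.
layer = nn.Linear(4, 2)
x = torch.randn(3, 4)

# Without no_grad, autograd records the operation so the output could be backpropagated.
out_tracked = layer(x)

# Inside no_grad, no graph is recorded and the output does not require grad.
with torch.no_grad():
    out_untracked = layer(x)

print(out_tracked.requires_grad)    # True
print(out_untracked.requires_grad)  # False
print(out_untracked.shape)          # torch.Size([3, 2])
```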
And here’s the Lightning code. Do I need to add `torch.no_grad()` explicitly here as well?
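```python
from typing import List

import lightning as L
from beartype import beartype
from torch.nn import Module


@beartype
def extract_output_shapes(
    modules: List[Module],  # hooked layers are usually plain nn.Module submodules
    model: L.LightningModule,
    model_input,
    model_kwargs: dict = dict()
):
    shapes = []
    hooks = []

    def hook_fn(module, input, output):
        # record the output shape; returning None leaves the output unchanged
        shapes.append(output.shape)

    for module in modules:
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

    model(model_input, **model_kwargs)

    for hook in hooks:
        hook.remove()

    return shapes
```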
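Calling `model(...)` yourself runs an ordinary forward pass; the Trainer only turns gradient tracking off inside the evaluation loops it drives (validation, test, predict). So if this helper is called outside those loops, keeping `torch.no_grad()` is still worthwhile, and if it is called from inside e.g. `validation_step`, the extra context manager is harmless. Below is a sketch of the helper under that assumption, typed against `nn.Module` (which a `LightningModule` also satisfies), with the hooks removed in a `finally` block. `beartype` is dropped only to keep the example dependency-free, and the demo `nn.Sequential` model is hypothetical.

```python
from typing import Any, Dict, List, Optional

import torch
from torch import nn


def extract_output_shapes(
    modules: List[nn.Module],
    model: nn.Module,
    model_input: Any,
    model_kwargs: Optional[Dict[str, Any]] = None,
) -> List[torch.Size]:
    """Run one forward pass and record the output shape of each hooked module."""
    model_kwargs = model_kwargs or {}
    shapes: List[torch.Size] = []

    def hook_fn(_module, _input, output):
        # Returning None leaves the module's output unchanged.
        shapes.append(output.shape)

    hooks = [m.register_forward_hook(hook_fn) for m in modules]
    try:
        # Nothing outside the Trainer's eval loops disables autograd for us,
        # so skip building a graph that is never used.
        with torch.no_grad():
            model(model_input, **model_kwargs)
    finally:
        # Remove the hooks even if the forward pass raises.
        for hook in hooks:
            hook.remove()
    return shapes


# Hypothetical demo model: shapes are recorded in execution order.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
shapes = extract_output_shapes([model[0], model[2]], model, torch.randn(2, 8))
print(shapes)  # [torch.Size([2, 16]), torch.Size([2, 4])]
```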
Also, is this the correct way to register hooks? When I use `module.register_forward_hook`, it resolves to the PyTorch (`torch.nn.Module`) implementation rather than anything Lightning-specific.
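`LightningModule` subclasses `torch.nn.Module`, so `register_forward_hook` is the inherited PyTorch method, which is why an editor jumps to the torch source. The sketch below uses a hypothetical `MyModel(nn.Module)` as a stand-in for a `LightningModule` subclass, shows that method lookup lands on `nn.Module`, and registers hooks on its submodules via `named_children()`:

```python
import torch
from torch import nn


# Hypothetical stand-in for a LightningModule subclass; LightningModule itself
# subclasses nn.Module, so the hook API is inherited the same way.
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(8, 16)
        self.head = nn.Linear(16, 4)

    def forward(self, x):
        return self.head(torch.relu(self.encoder(x)))


model = MyModel()

# Method resolution finds the implementation defined on nn.Module.
print(type(model).register_forward_hook is nn.Module.register_forward_hook)  # True

seen = {}


def make_hook(name):
    # Each hook records the output shape under its submodule's name.
    def hook(_module, _inputs, output):
        seen[name] = tuple(output.shape)
    return hook


# Hooks go on the submodules (plain nn.Linear layers here).
handles = [m.register_forward_hook(make_hook(n)) for n, m in model.named_children()]
with torch.no_grad():
    model(torch.randn(2, 8))
for h in handles:
    h.remove()

print(seen)  # {'encoder': (2, 16), 'head': (2, 4)}
```

With Lightning installed, `issubclass(L.LightningModule, torch.nn.Module)` should evaluate to `True`.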