Converting PyTorch to Lightning code

Hi, I am trying to convert PyTorch code to Lightning, and I was wondering if I need to explicitly add torch.no_grad().

Here’s the PyTorch code

from typing import List

import torch
from beartype import beartype
from torch.nn import Module

@beartype
def extract_output_shapes(
    modules: List[Module],
    model: Module,
    model_input,
    model_kwargs: dict = dict()
):
    shapes = []
    hooks = []

    def hook_fn(_, input, output):
        # record the shape of this module's output; the hook itself returns nothing
        shapes.append(output.shape)

    for module in modules:
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)

    # should I add this line of code?
    with torch.no_grad():
        model(model_input, **model_kwargs)

    for hook in hooks:
        hook.remove()

    return shapes
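
For context, here is roughly how I call it in the PyTorch version (the toy model below is just a placeholder for my real network):

import torch
from torch import nn

# throwaway model, only to illustrate the call
toy_model = nn.Sequential(
    nn.Linear(16, 32),
    nn.ReLU(),
    nn.Linear(32, 8),
)

# record the output shape of each Linear layer during one forward pass
linear_layers = [m for m in toy_model if isinstance(m, nn.Linear)]
shapes = extract_output_shapes(
    modules=linear_layers,
    model=toy_model,
    model_input=torch.randn(4, 16),
)
print(shapes)  # [torch.Size([4, 32]), torch.Size([4, 8])]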

And here’s the Lightning code. Do I need to add torch.no_grad() explicitly here as well?

from typing import List

import lightning as L
from beartype import beartype

@beartype
def extract_output_shapes(
    modules: List[L.LightningModule],
    model: L.LightningModule,
    model_input,
    model_kwargs: dict = dict()
):
    shapes = []
    hooks = []

    def hook_fn(module, input, output):
        # record the shape of this module's output
        shapes.append(output.shape)
    
    for module in modules:
        hook = module.register_forward_hook(hook_fn)
        hooks.append(hook)
    
    model(model_input, **model_kwargs)

    for hook in hooks:
        hook.remove()

    return shapes

Also, is this the correct way to register hooks? I see that when I use module.register_forward_hook, it references the torch implementation and not a Lightning-specific one.

A LightningModule is also an nn.Module, so it will behave the same way. Since you are only registering a forward hook and never run backward, the torch.no_grad() context is a good idea: it makes the forward call faster and uses less memory.
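
Here is a minimal sketch of what that could look like in your Lightning version (TinyLitModel below is just a placeholder, not your actual model):

import torch
import lightning as L
from torch import nn

# placeholder LightningModule, only for illustration
class TinyLitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(16, 8)

    def forward(self, x):
        return self.layer(x)

lit_model = TinyLitModel()

# LightningModule subclasses nn.Module, so register_forward_hook is the same torch API
print(isinstance(lit_model, nn.Module))  # True

shapes = []
hook = lit_model.layer.register_forward_hook(
    lambda module, input, output: shapes.append(output.shape)
)

# no gradients are needed just to record shapes, so disable autograd
with torch.no_grad():
    lit_model(torch.randn(4, 16))

hook.remove()
print(shapes)  # [torch.Size([4, 8])]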
