How does `LightningOptimizer.zero_grad()` work?

According to the source code in the GitHub repo, there’s no implementation of `zero_grad()` on `LightningOptimizer`, nor any logic for calling `Optimizer.zero_grad()`. Is this intentional? If so, how does `LightningOptimizer.zero_grad()` actually work?

Hey @kaparoo

Yeah, it’s a bit hard to see, but `LightningOptimizer` basically just wraps the given optimizer. Any method you can call on the original optimizer can also be called on the `LightningOptimizer`, including `zero_grad()`. The magic line that makes this possible is this one (in `__init__`):

```python
self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}
```
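To make the idea concrete, here’s a minimal sketch (not the actual Lightning source) of this wrapping pattern: copying the inner optimizer’s `__dict__` makes the wrapper share its state (`param_groups`, `state`, etc.), and delegating unresolved attribute lookups forwards methods like `zero_grad()` to the wrapped optimizer. The class name `WrappedOptimizer` is made up for illustration.

```python
import torch


class WrappedOptimizer:
    """Illustrative sketch of the LightningOptimizer wrapping idea (not the real source)."""

    def __init__(self, optimizer):
        # Share the inner optimizer's attribute dict (param_groups, state, ...)
        # so reads like wrapper.param_groups see the same objects.
        self.__dict__ = {
            k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")
        }
        self._optimizer = optimizer

    def __getattr__(self, name):
        # Only called when normal lookup fails: methods such as zero_grad()
        # live on the optimizer *class*, not the instance __dict__, so they
        # are forwarded to the wrapped optimizer here.
        return getattr(self._optimizer, name)


params = [torch.nn.Parameter(torch.ones(2))]
opt = WrappedOptimizer(torch.optim.SGD(params, lr=0.1))

params[0].sum().backward()
opt.zero_grad()  # delegates to SGD.zero_grad() via __getattr__
```

So `zero_grad()` never needs to be reimplemented: attribute lookup simply falls through to the underlying `torch.optim.Optimizer`.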

Aha! Now I’m clear. Thanks!
