LightningOptimizer
- class lightning.pytorch.core.optimizer.LightningOptimizer(optimizer)[source]
Bases: object
This class wraps the user's optimizers and properly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.
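The snippet below is a minimal sketch of where this wrapper typically appears: with manual optimization, self.optimizers() returns LightningOptimizer instances wrapping whatever configure_optimizers() produced. The module, layer sizes, and optimizer choice are illustrative assumptions, not part of this API.
import torch
import lightning.pytorch as pl


class SketchModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        # self.optimizers() hands back the optimizer wrapped as a LightningOptimizer
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        opt.step()  # redirected to the wrapped optimizer's step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)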
Note: The purpose of this wrapper is only to define new methods and redirect the .step() call. The internal state __dict__ is not kept in sync with the internal state of the original optimizer, but the Trainer never relies on the internal state of the wrapper.
- refresh()[source]
Refreshes the __dict__ so that it matches the internal state of the wrapped optimizer.
This is only needed to present the user with an updated view in case they inspect the state of this wrapper.
- Return type: None
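A brief hedged sketch of when refresh() might be called; it assumes manual optimization and that the wrapper mirrors attributes such as param_groups from the wrapped optimizer, as described in the note above.
def on_train_epoch_start(self):
    opt = self.optimizers()  # LightningOptimizer wrapper

    # If code elsewhere mutated the *wrapped* optimizer's state directly,
    # the wrapper's mirrored __dict__ may be stale; refresh() re-syncs it.
    opt.refresh()

    # Inspecting the wrapper now reflects the wrapped optimizer's current state.
    print(opt.param_groups[0]["lr"])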
- step(closure=None, **kwargs)[source]
Performs a single optimization step (parameter update).
- Parameters:
  - closure: An optional optimizer closure.
  - kwargs: Any additional arguments to the optimizer.step() call.
- Return type: Any
- Returns:
The output from the step call, which is generally the output of the closure execution.
Example:
# Scenario for a GAN using manual optimization
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()
    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)
    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()


# A more advanced example
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()
    ...
    accumulated_grad_batches = batch_idx % 2 == 0

    # compute generator loss
    def closure_gen():
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)
        if accumulated_grad_batches:
            opt_gen.zero_grad()

    with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
        opt_gen.step(closure=closure_gen)

    def closure_dis():
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)
        if accumulated_grad_batches:
            opt_dis.zero_grad()

    with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
        opt_dis.step(closure=closure_dis)
- toggle_model(sync_grad=True)[source]
This function is just a helper for advanced users.
Considering the current optimizer as A and all other optimizers as B, toggling means that all parameters from B exclusive to A will have requires_grad set to False.
When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting sync_grad to False will block this synchronization and improve performance.
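A smaller hedged sketch of toggle_model() in isolation, assuming manual optimization with two optimizers where opt_gen covers self.generator and opt_dis covers self.discriminator; generator_loss() is a hypothetical helper, not part of this API.
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    with opt_gen.toggle_model():
        # Parameters exclusive to opt_dis (the discriminator) now have
        # requires_grad set to False, so this backward only accumulates
        # gradients for the generator's parameters.
        loss_gen = self.generator_loss(batch)  # hypothetical helper
        opt_gen.zero_grad()
        self.manual_backward(loss_gen)
        opt_gen.step()
    # On exit, the previous requires_grad flags are restored.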