LightningOptimizer
- class lightning.pytorch.core.optimizer.LightningOptimizer(optimizer)
Bases: object
This class wraps the user's optimizers and correctly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.
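The wrapping idea can be pictured with a small, framework-free sketch. Everything in it (`Param`, `ToySGD`, `OptimizerWrapper`) is hypothetical and only mimics the shape of the real classes. The wrapper intercepts `step()`, which is where accelerator-, AMP-, or accumulation-specific handling would go, and forwards every other attribute to the wrapped optimizer.

```python
from typing import Any, Callable, List, Optional


class Param:
    """Hypothetical scalar parameter with a gradient slot."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class ToySGD:
    """Hypothetical plain-Python SGD that follows torch.optim's step(closure) contract."""

    def __init__(self, params: List[Param], lr: float = 0.1) -> None:
        self.params = params
        self.lr = lr

    def zero_grad(self) -> None:
        for p in self.params:
            p.grad = 0.0

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        loss = closure() if closure is not None else None  # closure runs first
        for p in self.params:
            p.value -= self.lr * p.grad
        return loss


class OptimizerWrapper:
    """Hypothetical wrapper in the spirit of LightningOptimizer (not Lightning's code)."""

    def __init__(self, optimizer: ToySGD) -> None:
        self._optimizer = optimizer
        self.num_steps = 0

    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        # Hook point: a real wrapper would route the call through accelerator/AMP logic.
        self.num_steps += 1
        return self._optimizer.step(closure=closure, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Only called for attributes the wrapper lacks (zero_grad, lr, ...): forward them.
        return getattr(self._optimizer, name)


w = Param(1.0)
opt = OptimizerWrapper(ToySGD([w], lr=0.5))
w.grad = 2.0     # pretend a backward pass produced this gradient
opt.step()       # counted by the wrapper, applied by ToySGD: w.value becomes 0.0
opt.zero_grad()  # not defined on the wrapper, so it is forwarded to ToySGD
```

Forwarding through `__getattr__` keeps the wrapper transparent: code written against the plain optimizer (here `zero_grad()` and `lr`) keeps working on the wrapped object.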
- step(closure=None, **kwargs)
Performs a single optimization step (parameter update).
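- Parameters:
  - closure (Optional[Callable[[], Any]]): An optional optimizer closure.
  - **kwargs (Any): Any additional arguments to the optimizer.step() call.
- Return type:
Any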
- Returns:
The output from the step call, which is generally the output of the closure execution.
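Because `step()` returns the closure's output, it helps to be clear about when the closure runs. In the sketch below, `ScalarSGD` and `make_closure` are hypothetical. `ScalarSGD` follows the torch.optim convention of evaluating the closure first and updating parameters afterwards, which is why clearing gradients inside the closure after the backward pass turns the update into a no-op.

```python
from typing import Callable, Optional


class ScalarSGD:
    """Hypothetical one-parameter SGD following torch.optim's closure convention."""

    def __init__(self, value: float, lr: float) -> None:
        self.value = value
        self.grad = 0.0
        self.lr = lr

    def zero_grad(self) -> None:
        self.grad = 0.0

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        # 1) run the closure first: it computes the loss and fills in self.grad
        loss = closure() if closure is not None else None
        # 2) then update with whatever gradient is present at this point
        self.value -= self.lr * self.grad
        # 3) hand the closure's output back to the caller
        return loss


def make_closure(opt: ScalarSGD, zero_grad_after_backward: bool) -> Callable[[], float]:
    """Build a closure for loss = value**2 (gradient 2 * value)."""

    def closure() -> float:
        if not zero_grad_after_backward:
            opt.zero_grad()  # correct: clear stale gradients before backward
        loss = opt.value ** 2
        opt.grad += 2 * opt.value  # stands in for manual_backward(loss)
        if zero_grad_after_backward:
            opt.zero_grad()  # bug: wipes the gradient before step() can use it
        return loss

    return closure


good = ScalarSGD(value=3.0, lr=0.5)
bad = ScalarSGD(value=3.0, lr=0.5)
good_loss = good.step(closure=make_closure(good, zero_grad_after_backward=False))  # 9.0
bad_loss = bad.step(closure=make_closure(bad, zero_grad_after_backward=True))      # 9.0
# good.value is now 0.0; bad.value is still 3.0 because its gradient was zeroed
```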
Example:
```python
# Scenario for a GAN using manual optimization
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()
    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)
    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()


# A more advanced example: accumulate gradients over 2 batches
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()
    ...
    # True on the last batch of each 2-batch accumulation window
    is_last_batch_to_accumulate = (batch_idx + 1) % 2 == 0

    # compute generator loss
    def closure_gen():
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)

    # sync gradients across processes only on the batch that completes the window
    with opt_gen.toggle_model(sync_grad=is_last_batch_to_accumulate):
        if is_last_batch_to_accumulate:
            opt_gen.step(closure=closure_gen)
            # reset only after the step has consumed the accumulated gradients
            opt_gen.zero_grad()
        else:
            closure_gen()  # backward only; gradients keep accumulating

    # compute discriminator loss
    def closure_dis():
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)

    with opt_dis.toggle_model(sync_grad=is_last_batch_to_accumulate):
        if is_last_batch_to_accumulate:
            opt_dis.step(closure=closure_dis)
            opt_dis.zero_grad()
        else:
            closure_dis()
```
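The advanced example accumulates gradients across batches. The bookkeeping behind that pattern can be checked in isolation: per-batch gradients are summed, and gradient synchronization, the optimizer step, and the reset happen only on the last batch of each window. The plain-Python sketch below assumes a hypothetical window of `n` batches and a single scalar parameter, and it sums (rather than averages) the per-batch gradients; `accumulation_schedule` and `accumulate` are illustrative names, not Lightning functions.

```python
from typing import List, Tuple


def accumulation_schedule(num_batches: int, n: int) -> List[Tuple[int, bool]]:
    """For each batch index, report whether to sync gradients and take a step.

    Sync + step happen only on the last batch of each n-batch window, which is
    the value one would pass as sync_grad in this scheme.
    """
    return [(batch_idx, (batch_idx + 1) % n == 0) for batch_idx in range(num_batches)]


def accumulate(grads: List[float], n: int, lr: float, value: float) -> float:
    """Toy scalar SGD that sums per-batch gradients and steps once per window."""
    grad = 0.0
    for batch_idx, g in enumerate(grads):
        grad += g  # backward adds into .grad instead of overwriting it
        if (batch_idx + 1) % n == 0:
            value -= lr * grad  # optimizer step with the accumulated gradient
            grad = 0.0  # zero_grad only after the step
    # a trailing partial window is left unstepped in this sketch
    return value


schedule = accumulation_schedule(num_batches=4, n=2)
final_value = accumulate([1.0, 3.0, 2.0, 2.0], n=2, lr=0.25, value=10.0)
```

Each window here contributes a gradient of 4.0, so two steps with `lr=0.25` move the parameter from 10.0 to 8.0.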
- toggle_model(sync_grad=True)
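This context manager is a helper for advanced users.
Consider the current optimizer as A and all other optimizers as B. Toggling sets requires_grad to False on every parameter that belongs to B but not to A, so backward passes inside the context do not compute gradients for parameters that only other optimizers update. The previous requires_grad values are restored when the context exits.
When performing gradient accumulation, there is no need to synchronize gradients (e.g., across processes in distributed training) during the accumulation phase. Setting sync_grad to False blocks this synchronization and improves performance.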
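The toggling rule can be expressed in a few lines of plain Python. This is a sketch, not Lightning's implementation: `Param` and `toggle` are hypothetical stand-ins for real parameters and the context manager, and parameters are matched by identity so that a parameter shared by A and B stays trainable.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Tuple


class Param:
    """Hypothetical parameter carrying only a name and a requires_grad flag."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.requires_grad = True


@contextmanager
def toggle(current: List[Param], others: List[List[Param]]) -> Iterator[None]:
    """Freeze parameters owned by other optimizers (B) but not by the current one (A)."""
    current_ids = {id(p) for p in current}
    saved: Dict[int, Tuple[Param, bool]] = {}
    for group in others:
        for p in group:
            if id(p) not in current_ids and id(p) not in saved:
                saved[id(p)] = (p, p.requires_grad)  # remember the original flag
                p.requires_grad = False
    try:
        yield
    finally:
        # restore every flag we changed, even if the body raised
        for p, flag in saved.values():
            p.requires_grad = flag


g1, d1, shared = Param("g1"), Param("d1"), Param("shared")
gen_params, dis_params = [g1, shared], [d1, shared]

# toggle the generator's optimizer (A); the discriminator's (B) is the "other" one
with toggle(gen_params, [dis_params]):
    inside = {p.name: p.requires_grad for p in (g1, d1, shared)}
after = {p.name: p.requires_grad for p in (g1, d1, shared)}
```

Inside the context only `d1` is frozen; `shared` keeps its flag because A also owns it. Restoring in a `finally` block means an exception raised inside the context does not leave parameters frozen.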