LightningOptimizer

class pytorch_lightning.core.optimizer.LightningOptimizer(optimizer)

Bases: object

This class wraps user optimizers and correctly handles the backward and optimizer_step logic across accelerators, AMP, and accumulate_grad_batches.
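
The wrapping idea can be illustrated with a minimal pure-Python sketch. RecordingOptimizer and WrappedOptimizer are hypothetical names, not Lightning classes, and before_step only stands in for the accelerator, AMP, and accumulation handling the real class performs. In the sketch, step() is routed through the wrapper, while every other attribute falls through to the wrapped optimizer.

```python
from typing import Any, Callable, List, Optional


class RecordingOptimizer:
    """Hypothetical stand-in for a user optimizer; records what happens to it."""

    def __init__(self) -> None:
        self.lr = 0.1
        self.calls: List[str] = []

    def zero_grad(self) -> None:
        self.calls.append("zero_grad")

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        self.calls.append("step")
        return closure() if closure is not None else None


class WrappedOptimizer:
    """Hypothetical wrapper: step() goes through extra logic, everything else is forwarded."""

    def __init__(self, optimizer: Any, before_step: Callable[[], None]) -> None:
        self._optimizer = optimizer
        self._before_step = before_step  # stand-in for accelerator/AMP handling

    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        self._before_step()
        return self._optimizer.step(closure=closure, **kwargs)

    def __getattr__(self, name: str) -> Any:
        # Called only when normal lookup on the wrapper fails, so names such
        # as zero_grad or lr resolve to the wrapped optimizer.
        return getattr(self._optimizer, name)


inner = RecordingOptimizer()
opt = WrappedOptimizer(inner, before_step=lambda: inner.calls.append("before_step"))
opt.zero_grad()  # forwarded via __getattr__
result = opt.step(closure=lambda: "loss")
```

After these calls, inner.calls records zero_grad, then the wrapper's hook, then the inner step, and result is whatever the closure returned.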

step(closure=None, **kwargs)

Performs a single optimization step (parameter update).

Parameters:
  • closure (Optional[Callable[[], Any]]) – An optional optimizer closure, typically one that computes the loss and calls manual_backward.

  • kwargs (Any) – Any additional arguments to the optimizer.step() call.

Return type:

Any

Returns:

The output of the step call, which is generally the value returned by the closure.
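
For a PyTorch-style optimizer, step(closure) evaluates the closure first and then applies the update, which is why the returned value is usually the closure's output. The ScalarSGD class below is a hypothetical single-parameter toy written to mimic that contract; it is not the wrapped optimizer's actual code.

```python
from typing import Any, Callable, Optional


class ScalarSGD:
    """Hypothetical toy optimizer for one scalar parameter, mimicking step(closure)."""

    def __init__(self, param: float, lr: float) -> None:
        self.param = param
        self.grad = 0.0
        self.lr = lr

    def step(self, closure: Optional[Callable[[], Any]] = None) -> Any:
        # Evaluate the closure first: it computes the loss and fills self.grad.
        loss = closure() if closure is not None else None
        # Then apply the update using the freshly computed gradient.
        self.param -= self.lr * self.grad
        # Hand the closure's output back to the caller.
        return loss


opt = ScalarSGD(param=3.0, lr=0.1)


def closure() -> float:
    # Hypothetical loss (param - 1)**2, whose gradient is 2 * (param - 1).
    opt.grad = 2 * (opt.param - 1)
    return (opt.param - 1) ** 2


loss = opt.step(closure=closure)
```

Here loss is the value computed before the update, while opt.param already reflects the update made with that loss's gradient.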

Example:

# Scenario for a GAN using manual optimization
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)

    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()


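# A more advanced example: accumulate gradients over 2 batches
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...
    # True on the batch that completes an accumulation window. Only then do
    # the optimizers step and gradients need to be synchronized.
    is_last_accumulation_batch = (batch_idx + 1) % 2 == 0

    # compute generator loss and run its backward pass
    def closure_gen():
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)
        return loss_gen

    with opt_gen.toggle_model(sync_grad=is_last_accumulation_batch):
        if is_last_accumulation_batch:
            opt_gen.step(closure=closure_gen)
            # clear gradients only after the update has used them
            opt_gen.zero_grad()
        else:
            closure_gen()  # accumulate gradients without stepping

    # compute discriminator loss and run its backward pass
    def closure_dis():
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)
        return loss_dis

    with opt_dis.toggle_model(sync_grad=is_last_accumulation_batch):
        if is_last_accumulation_batch:
            opt_dis.step(closure=closure_dis)
            opt_dis.zero_grad()
        else:
            closure_dis()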
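
The comments in the first example stress that zero_grad must come before backward. The reason is that backward adds into each parameter's gradient rather than overwriting it, so gradients have to be cleared before the next backward pass, but never between a backward pass and the step that consumes it. The pure-Python toy below (Param and sgd_step are hypothetical stand-ins, not PyTorch or Lightning APIs) mimics that additive behavior and compares three orderings.

```python
class Param:
    """Hypothetical scalar parameter whose .grad accumulates like a tensor's."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0

    def backward(self, grad: float) -> None:
        # Like autograd, backward adds into .grad instead of overwriting it.
        self.grad += grad

    def zero_grad(self) -> None:
        self.grad = 0.0


def sgd_step(p: Param, lr: float) -> None:
    p.value -= lr * p.grad


# Correct order: zero_grad -> backward -> step.
good = Param(1.0)
for _ in range(2):
    good.zero_grad()
    good.backward(2.0)
    sgd_step(good, lr=0.5)

# Missing zero_grad: the second step uses the stale gradient plus the new one.
stale = Param(1.0)
for _ in range(2):
    stale.backward(2.0)
    sgd_step(stale, lr=0.5)

# Zeroing between backward and step: every update sees a zero gradient.
wiped = Param(1.0)
for _ in range(2):
    wiped.backward(2.0)
    wiped.zero_grad()
    sgd_step(wiped, lr=0.5)
```

Only the first ordering applies each batch's gradient exactly once; the second double-counts, and the third never moves the parameter.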

toggle_model(sync_grad=True)

This context manager is a helper for advanced users.

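Consider the current optimizer as A and all other optimizers as B. Toggling sets requires_grad to False on every parameter that belongs to B but not to A, so a backward pass inside the with block does not compute gradients for parameters that only B updates. The previous requires_grad values are restored when the block finishes.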
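
The toggling rule can be sketched in pure Python by tracking requires_grad flags by parameter name. The toggled function and the parameter names are hypothetical; this only models the set logic (freeze B's parameters minus A's, then restore), not Lightning's implementation.

```python
from contextlib import contextmanager
from typing import Dict, Iterator, List, Set


@contextmanager
def toggled(
    requires_grad: Dict[str, bool], current: Set[str], others: List[Set[str]]
) -> Iterator[None]:
    """Hypothetical sketch: current is optimizer A's parameters, others are B's sets."""
    saved = dict(requires_grad)
    # Parameters owned by some optimizer in B but not shared with A are frozen.
    frozen = set().union(*others) - current
    for name in frozen:
        requires_grad[name] = False
    try:
        yield
    finally:
        # Restore the original flags when the block exits.
        requires_grad.update(saved)


flags = {"gen.w": True, "dis.w": True, "shared.emb": True}
gen_params = {"gen.w", "shared.emb"}
dis_params = {"dis.w", "shared.emb"}

with toggled(flags, current=gen_params, others=[dis_params]):
    inside = dict(flags)
after = dict(flags)
```

The shared parameter stays trainable because it also belongs to A; only dis.w is frozen inside the block, and all flags are back to True afterwards.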

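During gradient accumulation, gradients do not need to be synchronized across processes (for example, in distributed data-parallel training) on the accumulation batches; they only need to be synchronized on the batch where the optimizer steps. Setting sync_grad to False blocks this synchronization inside the with block and avoids the communication overhead.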
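
As a rough illustration of the saving, the hypothetical sync_flags helper below (not a Lightning API) computes which batches would pass sync_grad=True for an accumulation window of accumulate_every batches: only the last batch of each window, where the optimizer steps.

```python
from typing import List


def sync_flags(num_batches: int, accumulate_every: int) -> List[bool]:
    """Hypothetical helper: True on the batches where the optimizer steps.

    Only those batches need cross-process gradient synchronization; on the
    others, gradients are accumulated locally (sync_grad=False).
    """
    return [(batch_idx + 1) % accumulate_every == 0 for batch_idx in range(num_batches)]


flags = sync_flags(num_batches=8, accumulate_every=4)
synced = sum(flags)       # synchronizations when accumulation batches are blocked
unblocked = len(flags)    # synchronizations if every batch synchronized
```

Each flag would be passed as toggle_model(sync_grad=...) for the corresponding batch. If the number of batches is not a multiple of the window, the final partial window never reaches a synchronizing batch and needs separate handling.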

Return type:

Generator[None, None, None]