LightningOptimizer

class lightning.pytorch.core.optimizer.LightningOptimizer(optimizer)

Bases: object

This class wraps the user's optimizers and handles the backward and optimizer_step logic correctly across accelerators, AMP, and accumulate_grad_batches.

Note: The purpose of this wrapper is only to define new methods and redirect the .step() call. The internal state __dict__ is not kept in sync with the internal state of the original optimizer, but the Trainer never relies on the internal state of the wrapper.

refresh()

Refreshes the __dict__ so that it matches the internal states in the wrapped optimizer.

This is only needed to present the user with an updated view in case they inspect the state of this wrapper.

Return type:

None
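The stale-view behaviour that refresh() addresses can be illustrated with a minimal sketch in plain Python. ToyOptimizer and SnapshotWrapper are hypothetical names invented for this illustration; they are not Lightning classes, and the real wrapper may copy its state differently.

```python
class ToyOptimizer:
    """Hypothetical stand-in for a real optimizer (not a Lightning class)."""

    def __init__(self, lr):
        self.lr = lr


class SnapshotWrapper:
    """Hypothetical wrapper that copies the wrapped object's attributes."""

    def __init__(self, optimizer):
        self._optimizer = optimizer
        self.refresh()

    def refresh(self):
        # Copy the wrapped object's current attributes into this wrapper.
        self.__dict__.update(vars(self._optimizer))


opt = ToyOptimizer(lr=0.1)
wrapped = SnapshotWrapper(opt)

opt.lr = 0.01            # change the wrapped object directly
stale_lr = wrapped.lr    # the wrapper still holds the value copied earlier
wrapped.refresh()
fresh_lr = wrapped.lr    # after refresh() the wrapper matches again
```

In this sketch the wrapper only reflects a change to the wrapped object once refresh() has been called, which is the situation the note on the class describes.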

step(closure=None, **kwargs)

Performs a single optimization step (parameter update).

Parameters:
  • closure (Optional[Callable[[], Any]]) – An optional optimizer closure.

  • kwargs (Any) – Any additional arguments to the optimizer.step() call.

Return type:

Any

Returns:

The output from the step call, which is generally the output of the closure execution.
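The relationship between the closure and the return value can be sketched without any framework. ToySGD below is a hypothetical one-parameter optimizer written for this illustration only; it follows the common convention that step() runs the closure, applies the update, and returns what the closure returned.

```python
class ToySGD:
    """Hypothetical one-parameter optimizer, for illustration only."""

    def __init__(self, lr):
        self.lr = lr
        self.param = 1.0
        self.grad = 0.0

    def step(self, closure=None):
        result = None
        if closure is not None:
            # The closure computes the loss and fills in the gradient.
            result = closure()
        # Plain gradient descent update.
        self.param -= self.lr * self.grad
        return result


opt = ToySGD(lr=0.5)


def closure():
    # loss = param ** 2, so d(loss)/d(param) = 2 * param
    loss = opt.param ** 2
    opt.grad = 2 * opt.param
    return loss


returned = opt.step(closure=closure)
```

Here the value returned by step() is the loss computed inside the closure, and the parameter has been updated with the gradient the closure produced. A closure that returns nothing makes step() return None in this sketch.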

Example:

# Scenario for a GAN using manual optimization
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...

    # compute generator loss
    loss_gen = self.compute_generator_loss(...)
    # zero_grad needs to be called before backward
    opt_gen.zero_grad()
    self.manual_backward(loss_gen)
    opt_gen.step()

    # compute discriminator loss
    loss_dis = self.compute_discriminator_loss(...)

    # zero_grad needs to be called before backward
    opt_dis.zero_grad()
    self.manual_backward(loss_dis)
    opt_dis.step()


# A more advanced example
def training_step(self, batch, batch_idx):
    opt_gen, opt_dis = self.optimizers()

    ...
    accumulated_grad_batches = batch_idx % 2 == 0

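    # compute generator loss
    def closure_gen():
        # zero_grad needs to be called before backward; calling it afterwards
        # would discard the gradients before the optimizer can apply them
        if accumulated_grad_batches:
            opt_gen.zero_grad()
        loss_gen = self.compute_generator_loss(...)
        self.manual_backward(loss_gen)
        # step() returns whatever the closure returns
        return loss_gen

    with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
        opt_gen.step(closure=closure_gen)

    # compute discriminator loss
    def closure_dis():
        if accumulated_grad_batches:
            opt_dis.zero_grad()
        loss_dis = self.compute_discriminator_loss(...)
        self.manual_backward(loss_dis)
        return loss_dis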

    with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
        opt_dis.step(closure=closure_dis)

toggle_model(sync_grad=True)

This function is just a helper for advanced users.

Consider the current optimizer as A and all other optimizers as B. Toggling means that every parameter of B that is not also a parameter of A has requires_grad set to False.

When performing gradient accumulation, there is no need to perform grad synchronization during the accumulation phase. Setting sync_grad to False will block this synchronization and improve performance.

Return type:

Generator[None, None, None]
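The toggling rule described above can be sketched as a small context manager in plain Python. Param and toggle are hypothetical names used only for this illustration; the sketch does not model sync_grad or any distributed behaviour, and it is not the library's implementation.

```python
from contextlib import contextmanager


class Param:
    """Hypothetical stand-in for a tensor parameter."""

    def __init__(self, name):
        self.name = name
        self.requires_grad = True


@contextmanager
def toggle(current_params, other_params):
    # Freeze parameters that belong only to the other optimizers (B)
    # and not to the current optimizer (A).
    current_ids = {id(p) for p in current_params}
    frozen = [
        p for p in other_params
        if id(p) not in current_ids and p.requires_grad
    ]
    for p in frozen:
        p.requires_grad = False
    try:
        yield
    finally:
        # Restore only the parameters this context manager froze.
        for p in frozen:
            p.requires_grad = True


shared = Param("shared")
gen_w = Param("gen_w")
dis_w = Param("dis_w")

# Generator optimizer (A) owns gen_w and shared; the other optimizer (B)
# owns dis_w and shared.
with toggle([gen_w, shared], [dis_w, shared]):
    inside = {p.name: p.requires_grad for p in (shared, gen_w, dis_w)}
after = {p.name: p.requires_grad for p in (shared, gen_w, dis_w)}
```

Inside the with block only dis_w is frozen, because shared also belongs to the current optimizer. On exit the original flags are restored.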