Optimization

Lightning offers two modes for managing the optimization process:

  • Manual Optimization

  • Automatic Optimization

For the majority of research cases, automatic optimization will do the right thing for you and it is what most users should use.

For more advanced use cases like multiple optimizers, esoteric optimization schedules or techniques, use manual optimization.


Manual Optimization

For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to manually manage the optimization process, especially when dealing with multiple optimizers at the same time.

In this mode, Lightning handles only the accelerator, precision, and strategy logic. You are responsible for everything else, such as calling optimizer.zero_grad(), gradient accumulation, and optimizer toggling.

To manually optimize, do the following:

  • Set self.automatic_optimization=False in your LightningModule’s __init__.

  • Use the following functions and call them manually:

    • self.optimizers() to access your optimizers (one or multiple)

    • optimizer.zero_grad() to clear the gradients from the previous training step

    • self.manual_backward(loss) instead of loss.backward()

    • optimizer.step() to update your model parameters

    • self.toggle_optimizer() and self.untoggle_optimizer() if needed

Here is a minimal example of manual optimization.

from lightning.pytorch import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()

Tip

Be careful where you call optimizer.zero_grad(), or your model won’t converge. It is good practice to call optimizer.zero_grad() before self.manual_backward(loss).
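To see why, here is a plain-Python toy (not Lightning or PyTorch code). ToyParam is a hypothetical stand-in for a parameter tensor that, like a PyTorch tensor, adds each backward pass into .grad instead of overwriting it. If the gradient is never cleared, every update uses the sum of all past gradients.

```python
# Toy stand-in for a single trainable scalar (hypothetical, not a PyTorch API).
# Like a PyTorch parameter, each backward pass ADDS into .grad, so stale
# gradients survive until they are explicitly cleared.
class ToyParam:
    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0

    def zero_grad(self) -> None:
        self.grad = 0.0

    def backward(self, x: float, y: float) -> float:
        # loss = (w * x - y) ** 2, so d(loss)/dw = 2 * x * (w * x - y)
        residual = self.value * x - y
        self.grad += 2.0 * x * residual
        return residual ** 2

    def step(self, lr: float) -> None:
        self.value -= lr * self.grad


def train(steps: int, clear_grads: bool, lr: float = 0.1) -> ToyParam:
    w = ToyParam(0.0)
    for _ in range(steps):
        if clear_grads:
            w.zero_grad()  # the role of opt.zero_grad() before manual_backward
        w.backward(x=1.0, y=2.0)
        w.step(lr)
    return w


good = train(steps=50, clear_grads=True)
bad = train(steps=50, clear_grads=False)
print(round(good.value, 4))  # close to the optimum y / x = 2.0
print(round(bad.value, 4))   # stale gradients pile up, so it oscillates instead of converging
```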

Access your Own Optimizer

The provided optimizer is a LightningOptimizer object wrapping your own optimizer configured in your configure_optimizers(). You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won’t be able to support accelerators, precision and profiling for you.

class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        ...

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        # `optimizer` is a `LightningOptimizer` wrapping the optimizer.
        # To access it, do the following.
        # However, it won't work on TPU, AMP, etc...
        optimizer = optimizer.optimizer
        ...

Gradient Accumulation

You can accumulate gradients over batches, similar to the accumulate_grad_batches argument of the Trainer in automatic optimization. To accumulate gradients with one optimizer and step it every N batches, do the following:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    # scale losses by 1/N (for N batches of gradient accumulation)
    loss = self.compute_loss(batch) / N
    self.manual_backward(loss)

    # accumulate gradients of N batches
    if (batch_idx + 1) % N == 0:
        opt.step()
        opt.zero_grad()
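As a sanity check of the 1/N scaling, the plain-Python sketch below (no Lightning or PyTorch) compares the accumulated gradient of N scaled micro-batch losses with the gradient of the mean loss over all samples at once. It assumes equally sized micro-batches and a loss that is a mean over samples; batch-dependent layers such as BatchNorm would break the equivalence. The data and helper names are hypothetical.

```python
import math

# Toy model: y_hat = w * x, per-sample loss (w * x - y) ** 2.
# A batch loss is the mean over its samples.


def batch_grad(w: float, batch) -> float:
    """d/dw of the mean squared error over one batch."""
    return sum(2.0 * x * (w * x - y) for x, y in batch) / len(batch)


def accumulated_grad(w: float, batches) -> float:
    """Sum of per-batch gradients, each scaled by 1/N as in `loss / N`."""
    n = len(batches)
    grad = 0.0  # plays the role of param.grad between zero_grad() calls
    for batch in batches:
        grad += batch_grad(w, batch) / n
    return grad


# N = 3 equally sized micro-batches of 2 samples each (hypothetical data)
micro_batches = [
    [(1.0, 2.0), (2.0, 3.0)],
    [(0.5, 1.0), (3.0, 5.0)],
    [(1.5, 2.5), (2.5, 4.0)],
]
big_batch = [sample for batch in micro_batches for sample in batch]

w = 0.3
acc = accumulated_grad(w, micro_batches)
full = batch_grad(w, big_batch)
print(acc, full)  # equal up to floating-point rounding
print(math.isclose(acc, full))
```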

Gradient Clipping

You can clip optimizer gradients during manual optimization, similar to passing the gradient_clip_val and gradient_clip_algorithm arguments to the Trainer in automatic optimization. To clip gradients with one optimizer in manual optimization, do the following:

from lightning.pytorch import LightningModule


class SimpleModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        # compute loss
        loss = self.compute_loss(batch)

        opt.zero_grad()
        self.manual_backward(loss)

        # clip gradients
        self.clip_gradients(opt, gradient_clip_val=0.5, gradient_clip_algorithm="norm")

        opt.step()

Warning

  • Note that configure_gradient_clipping() won’t be called in manual optimization. Instead, call self.clip_gradients() manually, as in the example above.
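The two gradient_clip_algorithm settings behave quite differently. The plain-Python sketch below is modeled on torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_, which to our understanding back the "norm" and "value" options. It works on a flat list of floats rather than real parameters, and the helper names are hypothetical.

```python
import math
from typing import List


def clip_by_norm(grads: List[float], max_norm: float, eps: float = 1e-6) -> List[float]:
    # Rescale the whole gradient vector so its global L2 norm is at most max_norm;
    # the direction of the update is preserved.
    total_norm = math.sqrt(sum(g * g for g in grads))
    coef = min(1.0, max_norm / (total_norm + eps))
    return [g * coef for g in grads]


def clip_by_value(grads: List[float], clip_value: float) -> List[float]:
    # Clamp each element independently into [-clip_value, clip_value];
    # this can change the direction of the update.
    return [max(-clip_value, min(clip_value, g)) for g in grads]


grads = [3.0, -4.0]  # global L2 norm = 5.0
print(clip_by_norm(grads, max_norm=0.5))     # same direction, norm shrunk to about 0.5
print(clip_by_value(grads, clip_value=0.5))  # [0.5, -0.5]
```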

Use Multiple Optimizers (like GANs)

Here is an example training a simple GAN with multiple optimizers using manual optimization.

import torch
from torch import Tensor
from lightning.pytorch import LightningModule


class SimpleGAN(LightningModule):
    def __init__(self):
        super().__init__()
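        # Generator and Discriminator are your own nn.Module subclasses. This example
        # also assumes you define self._Z (a latent distribution with a .sample()
        # method) and self.criterion (e.g. a binary cross-entropy loss).
        self.G = Generator()
        self.D = Discriminator()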

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def sample_z(self, n) -> Tensor:
        sample = self._Z.sample((n,))
        return sample

    def sample_G(self, n) -> Tensor:
        z = self.sample_z(n)
        return self.G(z)

    def training_step(self, batch, batch_idx):
        # Implementation follows the PyTorch tutorial:
        # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
        g_opt, d_opt = self.optimizers()

        X, _ = batch
        batch_size = X.shape[0]

        real_label = torch.ones((batch_size, 1), device=self.device)
        fake_label = torch.zeros((batch_size, 1), device=self.device)

        g_X = self.sample_G(batch_size)

        ##########################
        # Optimize Discriminator #
        ##########################
        d_x = self.D(X)
        errD_real = self.criterion(d_x, real_label)

        d_z = self.D(g_X.detach())
        errD_fake = self.criterion(d_z, fake_label)

        errD = errD_real + errD_fake

        d_opt.zero_grad()
        self.manual_backward(errD)
        d_opt.step()

        ######################
        # Optimize Generator #
        ######################
        d_z = self.D(g_X)
        errG = self.criterion(d_z, real_label)

        g_opt.zero_grad()
        self.manual_backward(errG)
        g_opt.step()

        self.log_dict({"g_loss": errG, "d_loss": errD}, prog_bar=True)

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
        d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
        return g_opt, d_opt

Learning Rate Scheduling

Every optimizer you use can be paired with any learning rate scheduler. See the documentation of configure_optimizers() for all the available options.

You can call lr_scheduler.step() at arbitrary intervals. Use self.lr_schedulers() in your LightningModule to access any learning rate schedulers defined in your configure_optimizers().

Warning

  • In manual optimization, you call lr_scheduler.step() yourself at whatever interval you choose. In automatic optimization, Lightning calls it according to the "interval" key in configure_optimizers() (which defaults to "epoch").

  • In manual optimization, lr_scheduler_config keys such as "frequency" and "interval" are ignored, even if you provide them in configure_optimizers().

Here is an example calling lr_scheduler.step() every step.

# step every batch
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    # single scheduler
    sch = self.lr_schedulers()
    sch.step()

    # multiple schedulers
    sch1, sch2 = self.lr_schedulers()
    sch1.step()
    sch2.step()

If you want to call lr_scheduler.step() every N steps/epochs, do the following.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    sch = self.lr_schedulers()

    # step every N batches
    if (batch_idx + 1) % N == 0:
        sch.step()

    # step every N epochs
    if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % N == 0:
        sch.step()
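To check which batches actually trigger sch.step() under the two conditions above, here is a plain-Python simulation of the loop. It assumes, as in Lightning's training loop, that batch_idx counts from 0 within each epoch and that is_last_batch is true only for the final batch of the epoch; the function name and its every parameter are hypothetical.

```python
from typing import List, Tuple


def scheduler_step_points(num_epochs: int, batches_per_epoch: int, n: int, every: str) -> List[Tuple[int, int]]:
    """Return the (epoch, batch_idx) pairs at which sch.step() would run."""
    points = []
    for epoch in range(num_epochs):
        for batch_idx in range(batches_per_epoch):  # batch_idx restarts every epoch
            is_last_batch = batch_idx == batches_per_epoch - 1
            if every == "batches" and (batch_idx + 1) % n == 0:
                points.append((epoch, batch_idx))
            elif every == "epochs" and is_last_batch and (epoch + 1) % n == 0:
                points.append((epoch, batch_idx))
    return points


print(scheduler_step_points(2, 10, 4, "batches"))  # [(0, 3), (0, 7), (1, 3), (1, 7)]
print(scheduler_step_points(4, 10, 2, "epochs"))   # [(1, 9), (3, 9)]
```

Note that with 10 batches per epoch and N = 4, the two leftover batches at the end of each epoch do not carry over, because batch_idx restarts at 0.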

If you want to call schedulers that require a metric value after each epoch, consider doing the following:

import torch


def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def on_train_epoch_end(self):
    sch = self.lr_schedulers()

    # If the selected scheduler is a ReduceLROnPlateau scheduler, pass it the monitored
    # metric. This assumes "loss" was logged with self.log(...) during the epoch.
    if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
        sch.step(self.trainer.callback_metrics["loss"])
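Plateau-style schedulers need the metric because their decision depends on its history. The plain-Python class below is a simplified sketch of that rule, loosely following the default mode="min" behavior of torch.optim.lr_scheduler.ReduceLROnPlateau but ignoring its threshold, cooldown, and min_lr options; PlateauSketch is a hypothetical name, not a PyTorch class.

```python
class PlateauSketch:
    """Simplified reduce-on-plateau rule (mode="min"); no threshold or cooldown."""

    def __init__(self, lr: float, factor: float = 0.1, patience: int = 2) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.num_bad_epochs = 0

    def step(self, metric: float) -> None:
        if metric < self.best:  # improvement: remember it and reset the counter
            self.best = metric
            self.num_bad_epochs = 0
        else:
            self.num_bad_epochs += 1
        if self.num_bad_epochs > self.patience:  # stalled for too long: decay the LR
            self.lr *= self.factor
            self.num_bad_epochs = 0


sch = PlateauSketch(lr=0.1, factor=0.5, patience=2)
for epoch_loss in [1.0, 0.8, 0.8, 0.81, 0.79, 0.9, 0.95, 0.99]:
    sch.step(epoch_loss)  # one call per epoch, with that epoch's loss
print(sch.lr)  # 0.05: only the last three non-improving epochs exceed the patience
```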

Optimizer Steps at Different Frequencies

In manual optimization, you are free to step() one optimizer more often than another one. For example, here we step the optimizer for the discriminator weights only on every other batch, i.e. half as often as the optimizer for the generator.

# Alternating schedule for optimizer steps (e.g. GANs)
def training_step(self, batch, batch_idx):
    g_opt, d_opt = self.optimizers()
    ...

    # update discriminator every other step
    if (batch_idx + 1) % 2 == 0:
        d_opt.zero_grad()
        self.manual_backward(errD)
        d_opt.step()

    ...

    # update generator every step
    g_opt.zero_grad()
    self.manual_backward(errG)
    g_opt.step()
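The stepping condition generalizes to any per-optimizer interval. This plain-Python helper (hypothetical, not part of Lightning) counts how many updates each optimizer receives when it steps on batches where (batch_idx + 1) % k == 0.

```python
from typing import Dict


def count_optimizer_steps(num_batches: int, step_every: Dict[str, int]) -> Dict[str, int]:
    """Count the updates each optimizer gets if it steps when (batch_idx + 1) % k == 0."""
    counts = {name: 0 for name in step_every}
    for batch_idx in range(num_batches):
        for name, k in step_every.items():
            if (batch_idx + 1) % k == 0:
                counts[name] += 1
    return counts


# Generator on every batch, discriminator on every other batch
print(count_optimizer_steps(10, {"g_opt": 1, "d_opt": 2}))  # {'g_opt': 10, 'd_opt': 5}
```

Swapping the intervals, e.g. {"g_opt": 2, "d_opt": 1}, would instead update the discriminator twice as often as the generator.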

Use Closure for LBFGS-like Optimizers

It is good practice to provide the optimizer with a closure function that performs the forward pass, zero_grad, and backward pass of your model. It is optional for most optimizers, but makes your code compatible if you switch to an optimizer which requires a closure, such as LBFGS.

See the PyTorch docs for more about the closure.

Here is an example using a closure function.

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def configure_optimizers(self):
    return torch.optim.LBFGS(...)


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        loss = self.compute_loss(batch)
        opt.zero_grad()
        self.manual_backward(loss)
        return loss

    opt.step(closure=closure)

Warning

The LBFGS optimizer is not supported for AMP or DeepSpeed.
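Why a closure at all? Some optimizers need to re-evaluate the loss and gradient more than once per step. The plain-Python toy below is a hypothetical backtracking line-search optimizer, not LBFGS and not a PyTorch class, that calls the closure at each trial point; ToyParam stands in for a parameter tensor with a .grad slot.

```python
from typing import Callable, List


class ToyParam:
    """Scalar stand-in for a parameter tensor with a .grad slot."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


class BacktrackingSGD:
    """Toy optimizer that, like LBFGS, may call the closure several times per step."""

    def __init__(self, param: ToyParam, lr: float = 1.0, shrink: float = 0.5, max_tries: int = 10) -> None:
        self.param = param
        self.lr = lr
        self.shrink = shrink
        self.max_tries = max_tries

    def step(self, closure: Callable[[], float]) -> float:
        loss = closure()  # loss and gradient at the current point
        start, grad, lr = self.param.value, self.param.grad, self.lr
        for _ in range(self.max_tries):
            self.param.value = start - lr * grad
            trial_loss = closure()  # re-run forward/backward at the trial point
            if trial_loss < loss:
                return trial_loss
            lr *= self.shrink  # overshot: try a smaller step
        self.param.value = start  # no improvement found: keep the old value
        return loss


w = ToyParam(5.0)
calls: List[int] = [0]


def closure() -> float:
    # forward, zero_grad and backward for loss = (w - 2) ** 2
    calls[0] += 1
    w.grad = 0.0
    loss = (w.value - 2.0) ** 2
    w.grad += 2.0 * (w.value - 2.0)
    return loss


BacktrackingSGD(w).step(closure)
print(w.value, calls[0])  # 2.0 3
```

A single backward pass computed before calling step() would not be enough here: the optimizer needs fresh losses at points it chooses itself.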


Automatic Optimization

With Lightning, most users don’t have to think about when to call .zero_grad(), .backward() and .step() since Lightning automates that for you.

Under the hood, Lightning does the following:

for epoch in epochs:
    for batch_idx, batch in enumerate(data):

        def closure():
            loss = model.training_step(batch, batch_idx)
            optimizer.zero_grad()
            loss.backward()
            return loss

        optimizer.step(closure)

    lr_scheduler.step()

As can be seen in the code snippet above, Lightning defines a closure with training_step(), optimizer.zero_grad() and loss.backward() for the optimization. This mechanism is in place to support optimizers which operate on the output of the closure (e.g. the loss) or need to call the closure several times (e.g. LBFGS).

Should you still require the flexibility of calling .zero_grad(), .backward(), or .step() yourself, you can always switch to manual optimization. Manual optimization is required if you wish to work with multiple optimizers.

Gradient Accumulation

Gradient accumulation runs K small batches of size N, calling backward on each, before doing a single optimizer step. The effect is a large effective batch size of K*N, where N is the batch size. Internally, Lightning doesn’t stack the batches into one large forward pass; instead, it accumulates the gradients over K batches and then calls optimizer.step(), so the effective batch size increases without the memory overhead of a larger batch.

Warning

When using distributed training, e.g. DDP with P devices, each device accumulates gradients independently: it stores the gradients after each loss.backward() and doesn’t sync them across devices until optimizer.step() is called. So during each accumulation step the effective batch size on each device remains N*K, but right before optimizer.step() the gradient sync makes the effective batch size P*N*K. For DP, since the batch is split across devices, the final effective batch size will be N*K.

from lightning.pytorch import Trainer

# DEFAULT (i.e. no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)

# Accumulate gradients for 7 batches
trainer = Trainer(accumulate_grad_batches=7)
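The batch-size rules from the warning above can be written down as a small plain-Python helper. It is a hypothetical illustration of that arithmetic, not a Lightning API; the strategy strings are just labels for the two cases described.

```python
def effective_batch_size(batch_size: int, accumulate_grad_batches: int, num_devices: int, strategy: str) -> int:
    """Samples contributing to one optimizer step, following the rules described above."""
    if strategy == "ddp":
        # every device loads its own batch of N samples and accumulates K of them
        return num_devices * batch_size * accumulate_grad_batches
    if strategy == "dp":
        # one batch of N samples is split across the devices
        return batch_size * accumulate_grad_batches
    raise ValueError(f"unsupported strategy: {strategy!r}")


print(effective_batch_size(32, 7, 4, "ddp"))  # 896 = P * N * K
print(effective_batch_size(32, 7, 4, "dp"))   # 224 = N * K
```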

Optionally, you can make the accumulate_grad_batches value change over time by using the GradientAccumulationScheduler. Pass in a scheduling dictionary, where the key represents the epoch at which the value for gradient accumulation should be updated.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import GradientAccumulationScheduler

# For epochs 0-3, accumulate every 8 batches; for epochs 4-7, accumulate
# every 4 batches; from epoch 8 onward, no accumulation happens.
# Note that the epoch keys here are zero-indexed.
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(callbacks=accumulator)

Note: Not all strategies and accelerators support variable gradient accumulation windows.
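One way to read the scheduling dictionary is that each epoch uses the value of the largest key not exceeding it. The plain-Python sketch below encodes that reading; it is not Lightning's implementation, and falling back to a factor of 1 before the first key is an assumption of this sketch.

```python
from typing import Dict


def accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    """Accumulation factor in effect at a zero-indexed epoch."""
    started = [key for key in scheduling if key <= epoch]
    # Assumption of this sketch: before the first key, no accumulation (factor 1)
    return scheduling[max(started)] if started else 1


schedule = {0: 8, 4: 4, 8: 1}
print([accumulation_factor(schedule, e) for e in range(10)])
# [8, 8, 8, 8, 4, 4, 4, 4, 1, 1]
```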

Access your Own Optimizer

In automatic optimization, the optimizer passed to the optimizer_step() hook is likewise a LightningOptimizer object wrapping the optimizer you configured in configure_optimizers(). You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won’t be able to support accelerators, precision and profiling for you.

# function hook in LightningModule
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    optimizer.step(closure=optimizer_closure)


# `optimizer` is a `LightningOptimizer` wrapping the optimizer.
# To access it, do the following.
# However, it won't work on TPU, AMP, etc...
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_closure,
):
    optimizer = optimizer.optimizer
    optimizer.step(closure=optimizer_closure)

Bring your own Custom Learning Rate Schedulers

Lightning allows using custom learning rate schedulers that aren’t available natively in PyTorch. One good example is the timm schedulers. When using custom learning rate schedulers that rely on a different API from the native PyTorch ones, you should override the lr_scheduler_step() hook with your desired logic. If you are using native PyTorch schedulers, there is no need to override this hook, since Lightning handles them automatically by default.

from timm.scheduler import TanhLRScheduler


def configure_optimizers(self):
    optimizer = ...
    scheduler = TanhLRScheduler(optimizer, ...)
    return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


def lr_scheduler_step(self, scheduler, metric):
    scheduler.step(epoch=self.current_epoch)  # timm's scheduler needs the epoch value

Configure Gradient Clipping

To customize gradient clipping, override the configure_gradient_clipping() method. The Trainer’s gradient_clip_val and gradient_clip_algorithm settings are passed to it as the corresponding arguments, and Lightning handles the clipping for you. If you want to use different values of your own choosing and still let Lightning perform the clipping, call the built-in clip_gradients() method and pass it those values along with your optimizer.

Warning

Do not override the clip_gradients() method. If you want to customize gradient clipping, override configure_gradient_clipping() instead.

For example, here we relax the gradient clipping by doubling the clipping threshold after a certain number of epochs:

def configure_gradient_clipping(self, optimizer, gradient_clip_val, gradient_clip_algorithm):
    # gradient_clip_val is None if it was not set on the Trainer
    if self.current_epoch > 5 and gradient_clip_val is not None:
        gradient_clip_val = gradient_clip_val * 2

    # Lightning will handle the gradient clipping
    self.clip_gradients(optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

Total Stepping Batches

You can use the built-in Trainer property estimated_stepping_batches to compute the total number of stepping batches (optimizer steps) for the complete training run. The property takes the gradient accumulation factor and the distributed setting into account, so you don’t have to derive it manually. One case where this is helpful is the OneCycleLR scheduler, which requires a precomputed total_steps during initialization.

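import torch


def configure_optimizers(self):
    optimizer = ...
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    # total_steps counts optimizer steps, so the scheduler must step every batch
    return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "step"}}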
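For intuition about the number you get back, here is a simplified plain-Python approximation of the arithmetic. It is not Lightning’s implementation; it assumes a finite map-style dataset, a DistributedSampler-style split with drop_last=False, a fixed accumulation factor, and that the optimizer also steps on a partial accumulation window at the end of each epoch. Both helper names are hypothetical.

```python
import math


def batches_per_device(dataset_len: int, batch_size: int, world_size: int) -> int:
    # DistributedSampler-style split (drop_last=False): each rank gets
    # ceil(dataset_len / world_size) samples, which are then batched.
    samples_per_rank = math.ceil(dataset_len / world_size)
    return math.ceil(samples_per_rank / batch_size)


def estimate_stepping_batches(batches_per_epoch: int, accumulate_grad_batches: int, max_epochs: int, max_steps: int = -1) -> int:
    # One optimizer step per accumulation window; a partial window at the end
    # of an epoch still counts as a step, hence the ceil.
    steps = math.ceil(batches_per_epoch / accumulate_grad_batches) * max_epochs
    return steps if max_steps == -1 else min(steps, max_steps)


per_device = batches_per_device(dataset_len=10_000, batch_size=32, world_size=4)
print(per_device)                                                # 79
print(estimate_stepping_batches(per_device, 4, max_epochs=10))   # 200
print(estimate_stepping_batches(per_device, 4, 10, max_steps=150))  # 150
```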