Optimization
Lightning offers two modes for managing the optimization process: manual optimization and automatic optimization.
For the majority of research cases, automatic optimization does the right thing for you, and it is what most users should use.
Advanced users who want esoteric optimization schedules or techniques should use manual optimization.
Manual Optimization
For advanced research topics like reinforcement learning, sparse coding, or GAN research, it may be desirable to manage the optimization process manually.
This is only recommended for experts who need ultimate flexibility.
Lightning handles only the accelerator, precision, and strategy logic.
Users are responsible for optimizer.zero_grad(), gradient accumulation, model toggling, etc.
To manually optimize, do the following:
1. Set self.automatic_optimization=False in your LightningModule's __init__.
2. Use the following functions and call them manually:
   - self.optimizers() to access your optimizers (one or multiple)
   - optimizer.zero_grad() to clear the gradients from the previous training step
   - self.manual_backward(loss) instead of loss.backward()
   - optimizer.step() to update your model parameters
Here is a minimal example of manual optimization.
from pytorch_lightning import LightningModule


class MyModel(LightningModule):
    def __init__(self):
        super().__init__()
        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        # compute_loss is a placeholder for your own loss computation
        loss = self.compute_loss(batch)
        self.manual_backward(loss)
        opt.step()
Warning
Before 1.2, optimizer.step() called optimizer.zero_grad() internally.
From 1.2 onward, calling optimizer.zero_grad() is left to the user.
Tip
Be careful where you call optimizer.zero_grad(), or your model won't converge.
It is good practice to call optimizer.zero_grad() before self.manual_backward(loss).
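PyTorch's backward() adds into each parameter's .grad buffer instead of overwriting it, so forgetting to clear the gradients feeds stale values into every later step. The pure-Python toy below imitates that accumulate-into-.grad behavior on a one-parameter least-squares problem to show the effect. Param and backward are hypothetical helpers for illustration. They are not Lightning or PyTorch APIs.

```python
from typing import List


class Param:
    """A single scalar parameter with a PyTorch-style .grad buffer (toy)."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0


def backward(w: Param, x: float, y: float) -> None:
    # loss = (w * x - y) ** 2, so d(loss)/dw = 2 * (w * x - y) * x.
    # Like loss.backward() in PyTorch, this ADDS to .grad instead of overwriting it.
    w.grad += 2.0 * (w.value * x - y) * x


def train(steps: int, clear_grads: bool, lr: float = 0.05) -> List[float]:
    w = Param(0.0)
    history = []
    for _ in range(steps):
        if clear_grads:
            w.grad = 0.0  # plays the role of optimizer.zero_grad()
        backward(w, x=1.0, y=3.0)
        w.value -= lr * w.grad  # plays the role of optimizer.step() (plain SGD)
        history.append(w.value)
    return history


with_zero_grad = train(200, clear_grads=True)
without_zero_grad = train(200, clear_grads=False)
```

When the gradient is cleared, each step uses only the current gradient and w settles at 3.0. When it is not cleared, every step also re-applies all earlier gradients. In this toy setup, w then keeps oscillating around 3.0 instead of converging.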
Access your Own Optimizer
The provided optimizer is a LightningOptimizer object wrapping your own optimizer configured in your configure_optimizers(). You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators, precision, and profiling for you.
class Model(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False
        ...

    def training_step(self, batch, batch_idx):
        optimizer = self.optimizers()

        # `optimizer` is a `LightningOptimizer` wrapping the optimizer.
        # To access it, do the following.
        # However, it won't work on TPU, AMP, etc...
        optimizer = optimizer.optimizer
        ...
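The reason that stepping the inner optimizer loses Lightning's support is the usual wrapper pattern: any extra logic lives in the wrapper's step(), so calling the wrapped object directly skips it. The toy below shows only that pattern. ToyOptimizer and ToyWrapper are hypothetical classes and do not reproduce what LightningOptimizer actually does.

```python
class ToyOptimizer:
    """Stands in for your own optimizer from configure_optimizers()."""

    def __init__(self) -> None:
        self.steps = 0

    def step(self) -> None:
        self.steps += 1


class ToyWrapper:
    """Hypothetical wrapper that adds bookkeeping around step()."""

    def __init__(self, optimizer: ToyOptimizer) -> None:
        self.optimizer = optimizer  # analogous to the `.optimizer` attribute
        self.tracked_steps = 0

    def step(self) -> None:
        self.tracked_steps += 1  # stand-in for precision/profiling logic
        self.optimizer.step()


wrapped = ToyWrapper(ToyOptimizer())
wrapped.step()            # goes through the wrapper's extra logic
wrapped.optimizer.step()  # bypasses it entirely
```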
Gradient Accumulation
You can accumulate gradients over batches, similar to the accumulate_grad_batches argument of the Trainer in automatic optimization. To perform gradient accumulation with one optimizer, stepping after every N batches, do the following.
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    loss = self.compute_loss(batch)
    self.manual_backward(loss)

    # accumulate gradients of N batches (N is an integer you choose)
    if (batch_idx + 1) % N == 0:
        opt.step()
        opt.zero_grad()
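The snippet above sums the raw gradients of N batches, so the step sees N times the average gradient. If you instead want the accumulated gradient to match the gradient of the mean loss over one N-times-larger batch, you can divide each loss by N before self.manual_backward(). This holds when the batches are equally sized. The pure-Python sketch below checks that equivalence for a 1-D linear model with mean-squared error. grad_mse and accumulated_grad are hypothetical helpers.

```python
from typing import List, Tuple

Batch = List[Tuple[float, float]]  # (x, y) pairs for a 1-D linear model y ~ w * x


def grad_mse(w: float, batch: Batch) -> float:
    # gradient of mean((w * x - y) ** 2) over the batch with respect to w
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_grad(w: float, micro_batches: List[Batch]) -> float:
    n = len(micro_batches)
    grad = 0.0  # the gradient buffer right after opt.zero_grad()
    for batch in micro_batches:
        # backward() of (loss / N) adds grad / N into the buffer
        grad += grad_mse(w, batch) / n
    return grad  # what opt.step() sees once (batch_idx + 1) % N == 0
```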
Use Multiple Optimizers (like GANs)
Here is an example training a simple GAN with multiple optimizers using manual optimization.
import torch
from torch import Tensor
from pytorch_lightning import LightningModule


class SimpleGAN(LightningModule):
    def __init__(self):
        super().__init__()
        # Generator and Discriminator are your own nn.Module subclasses
        self.G = Generator()
        self.D = Discriminator()

        # Important: This property activates manual optimization.
        self.automatic_optimization = False

    def sample_z(self, n) -> Tensor:
        # self._Z is assumed to be a torch.distributions object defined elsewhere
        sample = self._Z.sample((n,))
        return sample

    def sample_G(self, n) -> Tensor:
        z = self.sample_z(n)
        return self.G(z)

    def training_step(self, batch, batch_idx):
        # Implementation follows the PyTorch tutorial:
        # https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html
        g_opt, d_opt = self.optimizers()

        X, _ = batch
        batch_size = X.shape[0]

        real_label = torch.ones((batch_size, 1), device=self.device)
        fake_label = torch.zeros((batch_size, 1), device=self.device)

        g_X = self.sample_G(batch_size)

        ##########################
        # Optimize Discriminator #
        ##########################
        d_x = self.D(X)
        # self.criterion is assumed to be a loss such as nn.BCELoss()
        errD_real = self.criterion(d_x, real_label)

        d_z = self.D(g_X.detach())
        errD_fake = self.criterion(d_z, fake_label)

        errD = errD_real + errD_fake

        d_opt.zero_grad()
        self.manual_backward(errD)
        d_opt.step()

        ######################
        # Optimize Generator #
        ######################
        d_z = self.D(g_X)
        errG = self.criterion(d_z, real_label)

        g_opt.zero_grad()
        self.manual_backward(errG)
        g_opt.step()

        self.log_dict({"g_loss": errG, "d_loss": errD}, prog_bar=True)

    def configure_optimizers(self):
        g_opt = torch.optim.Adam(self.G.parameters(), lr=1e-5)
        d_opt = torch.optim.Adam(self.D.parameters(), lr=1e-5)
        return g_opt, d_opt
Learning Rate Scheduling
Every optimizer you use can be paired with any Learning Rate Scheduler. Please see the documentation of configure_optimizers() for all the available options.
You can call lr_scheduler.step() at arbitrary intervals.
Use self.lr_schedulers() in your LightningModule to access any learning rate schedulers defined in your configure_optimizers().
Warning
In manual optimization, you call lr_scheduler.step() yourself at arbitrary intervals. In automatic optimization, Lightning calls it according to the "interval" defined in configure_optimizers().
Note that during manual optimization, lr_scheduler_config keys such as "frequency" and "interval" are ignored, even if you provide them in your configure_optimizers().
Here is an example calling lr_scheduler.step() every step.
# step every batch
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    # single scheduler
    sch = self.lr_schedulers()
    sch.step()

    # multiple schedulers
    sch1, sch2 = self.lr_schedulers()
    sch1.step()
    sch2.step()
If you want to call lr_scheduler.step() every N steps/epochs, do the following.
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_step(self, batch, batch_idx):
    # do forward, backward, and optimization
    ...

    sch = self.lr_schedulers()

    # step every N batches
    if (batch_idx + 1) % N == 0:
        sch.step()

    # step every N epochs
    if self.trainer.is_last_batch and (self.trainer.current_epoch + 1) % N == 0:
        sch.step()
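To see how often each condition above fires, the pure-Python loop below simulates it. It assumes batch_idx restarts at 0 every epoch and is_last_batch is True only on an epoch's final batch. count_scheduler_steps is a hypothetical helper, not part of Lightning.

```python
def count_scheduler_steps(num_epochs: int, batches_per_epoch: int, n: int, every: str) -> int:
    steps = 0
    for current_epoch in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            is_last_batch = batch_idx == batches_per_epoch - 1
            # "step every N batches" condition from the snippet above
            if every == "batches" and (batch_idx + 1) % n == 0:
                steps += 1
            # "step every N epochs" condition from the snippet above
            if every == "epochs" and is_last_batch and (current_epoch + 1) % n == 0:
                steps += 1
    return steps
```

With 10 batches per epoch and N=4, the per-batch condition fires after batches 3 and 7 of each epoch. Because batch_idx restarts, the last two batches of each epoch never trigger a step.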
If you want to call schedulers that require a metric value after each epoch, consider doing the following:
import torch


def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def training_epoch_end(self, outputs):
    sch = self.lr_schedulers()

    # If the selected scheduler is a ReduceLROnPlateau scheduler.
    # This assumes a metric named "loss" was logged with self.log(...).
    if isinstance(sch, torch.optim.lr_scheduler.ReduceLROnPlateau):
        sch.step(self.trainer.callback_metrics["loss"])
Use Closure for LBFGS-like Optimizers
It is good practice to provide the optimizer with a closure function that performs a forward, zero_grad, and backward of your model. It is optional for most optimizers, but it makes your code compatible if you switch to an optimizer that requires a closure, such as LBFGS.
See the PyTorch docs for more about the closure.
Here is an example using a closure function.
import torch


def __init__(self):
    super().__init__()
    self.automatic_optimization = False


def configure_optimizers(self):
    return torch.optim.LBFGS(...)


def training_step(self, batch, batch_idx):
    opt = self.optimizers()

    def closure():
        loss = self.compute_loss(batch)
        opt.zero_grad()
        self.manual_backward(loss)
        return loss

    opt.step(closure=closure)
Warning
The LBFGS optimizer is not supported for apex AMP, native AMP, IPUs, or DeepSpeed.
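A closure matters for optimizers that re-evaluate the loss more than once per step. The toy optimizer below is a simple backtracking gradient descent (hypothetical, and not LBFGS). It calls the closure once at the current parameters and again at every trial step. That is why the closure has to bundle the forward pass, zero_grad, and backward. BacktrackingGD and minimize are illustrative names only.

```python
from typing import Callable, List, Tuple


class BacktrackingGD:
    """Toy optimizer that may call the closure several times per step."""

    def __init__(self, params: List[float], lr: float = 1.0, max_halvings: int = 20) -> None:
        self.params = params
        self.grads = [0.0] * len(params)
        self.lr = lr
        self.max_halvings = max_halvings

    def step(self, closure: Callable[[], float]) -> float:
        loss = closure()  # loss and gradients at the current params
        start, grads = list(self.params), list(self.grads)
        lr = self.lr
        for _ in range(self.max_halvings):
            self.params[:] = [p - lr * g for p, g in zip(start, grads)]
            if closure() < loss:  # re-evaluate the closure at the trial point
                break
            lr *= 0.5
        else:
            self.params[:] = start  # no improvement found: keep the old params
        return loss


def minimize(target: List[float], steps: int) -> Tuple[List[float], int]:
    opt = BacktrackingGD([0.0] * len(target))
    calls = 0

    def closure() -> float:
        nonlocal calls
        calls += 1
        # forward: loss = sum((p - t) ** 2)
        loss = sum((p - t) ** 2 for p, t in zip(opt.params, target))
        opt.grads = [0.0] * len(target)  # zero_grad
        # backward: d(loss)/dp = 2 * (p - t), added into the zeroed buffer
        opt.grads = [g + 2.0 * (p - t) for g, p, t in zip(opt.grads, opt.params, target)]
        return loss

    for _ in range(steps):
        opt.step(closure)
    return opt.params, calls
```

A single call to step() here runs the closure three times: once at the start and twice during the line search.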
Automatic Optimization
With Lightning, most users don't have to think about when to call .zero_grad(), .backward() and .step(), since Lightning automates that for you.
Under the hood, Lightning does the following:
for epoch in epochs:
    for batch in data:

        def closure():
            loss = model.training_step(batch, batch_idx, ...)
            optimizer.zero_grad()
            loss.backward()
            return loss

        optimizer.step(closure)

    lr_scheduler.step()
In the case of multiple optimizers, Lightning does the following:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:

            def closure():
                loss = model.training_step(batch, batch_idx, optimizer_idx)
                opt.zero_grad()
                loss.backward()
                return loss

            opt.step(closure)

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
As can be seen in the code snippets above, Lightning defines a closure with training_step(), optimizer.zero_grad() and loss.backward() for the optimization. This mechanism is in place to support optimizers that operate on the output of the closure (e.g. the loss) or need to call the closure several times (e.g. LBFGS).
Warning
Before v1.2.2, Lightning internally called backward, step and zero_grad, in that order.
From v1.2.2, the order is changed to zero_grad, backward and step.
Gradient Accumulation
Accumulating gradients runs K small batches of size N before each optimizer step. The effect is a large effective batch size of N*K, where N is the batch size.
Internally, it doesn't stack up the batches and do a single forward pass. Rather, it accumulates the gradients over K batches and then calls optimizer.step, so the effective batch size increases without the memory overhead of a larger batch.
Warning
When using distributed training, e.g. DDP with P devices, each device accumulates independently. That is, it stores the gradients after each loss.backward() and doesn't sync them across devices until optimizer.step() is called. So for each accumulation step, the effective batch size on each device remains N*K, but right before optimizer.step(), the gradient sync makes the effective batch size P*N*K. For DP, since the batch is split across devices, the final effective batch size is N*K.
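The arithmetic in the warning above fits in a few lines. Here N is the per-device batch size for DDP and the full batch size for DP, K is accumulate_grad_batches, and P is the number of devices. effective_batch_size is a hypothetical helper, not a Lightning API.

```python
def effective_batch_size(n: int, k: int, p: int = 1, strategy: str = "ddp") -> int:
    if strategy == "ddp":
        # each of the P devices sees N samples per batch; gradients sync at optimizer.step()
        return p * n * k
    if strategy == "dp":
        # the batch of N is split across devices, so P does not multiply it
        return n * k
    raise ValueError(f"unknown strategy: {strategy}")
```

For example, DDP on 8 devices with N=32 and K=4 yields 1024 samples per optimizer step. DP with the same numbers yields 128.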
See also the Trainer's accumulate_grad_batches argument:
from pytorch_lightning import Trainer

# DEFAULT (ie: no accumulated grads)
trainer = Trainer(accumulate_grad_batches=1)

# Accumulate gradients for 7 batches
trainer = Trainer(accumulate_grad_batches=7)
You can set different values for it at different epochs by passing a dictionary, where the key represents the epoch at which the value for gradient accumulation should be updated.
# epochs 0-3: accumulate every 8 batches; epochs 4-7: accumulate every 4 batches;
# from epoch 8 onward: no accumulation. Note that the keys are zero-indexed epochs.
trainer = Trainer(accumulate_grad_batches={0: 8, 4: 4, 8: 1})
Or, you can create a custom GradientAccumulationScheduler callback:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# epochs 0-3: accumulate every 8 batches; epochs 4-7: accumulate every 4 batches;
# from epoch 8 onward: no accumulation. Note that the keys are zero-indexed epochs.
accumulator = GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1})
trainer = Trainer(callbacks=accumulator)
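One way to read a schedule such as {0: 8, 4: 4, 8: 1} is that the factor for a given epoch is the value of the largest key that is less than or equal to that epoch. The sketch below implements that reading. accumulation_factor is a hypothetical helper, not GradientAccumulationScheduler's source, and its default of 1 before the first key is an assumption of this sketch.

```python
from typing import Dict


def accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    factor = 1  # in this sketch: no accumulation before the first scheduled epoch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor
```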
Use Multiple Optimizers (like GANs)
To use multiple optimizers (optionally with learning rate schedulers), return two or more optimizers from configure_optimizers().
from torch.optim import SGD, Adam
from torch.optim.lr_scheduler import OneCycleLR, ReduceLROnPlateau, StepLR


# two optimizers, no schedulers
def configure_optimizers(self):
    return Adam(...), SGD(...)


# two optimizers, one scheduler for adam only
def configure_optimizers(self):
    opt1 = Adam(...)
    opt2 = SGD(...)
    optimizers = [opt1, opt2]
    lr_schedulers = {"scheduler": ReduceLROnPlateau(opt1, ...), "monitor": "metric_to_track"}
    return optimizers, lr_schedulers


# two optimizers, two schedulers
def configure_optimizers(self):
    opt1 = Adam(...)
    opt2 = SGD(...)
    return [opt1, opt2], [StepLR(opt1, ...), OneCycleLR(opt2, ...)]
Under the hood, Lightning will call each optimizer sequentially:
for epoch in epochs:
    for batch in data:
        for opt in optimizers:
            loss = model.training_step(batch, batch_idx, optimizer_idx)
            opt.zero_grad()
            loss.backward()
            opt.step()

    for lr_scheduler in lr_schedulers:
        lr_scheduler.step()
Step Optimizers at Arbitrary Intervals
To do more interesting things with your optimizers, such as learning rate warm-up or odd scheduling, override the optimizer_step() function.
Warning
If you override this method, make sure that you pass the optimizer_closure parameter to optimizer.step() as shown in the examples, because training_step(), optimizer.zero_grad(), and loss.backward() are called in the closure function.
For example, here we step optimizer A every batch and optimizer B every 2 batches.
# Alternating schedule for optimizer steps (e.g. GANs)
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    # update generator every step
    if optimizer_idx == 0:
        optimizer.step(closure=optimizer_closure)

    # update discriminator every 2 steps
    if optimizer_idx == 1:
        if (batch_idx + 1) % 2 == 0:
            # the closure (which includes the `training_step`) will be executed by `optimizer.step`
            optimizer.step(closure=optimizer_closure)
        else:
            # call the closure by itself to run `training_step` + `backward` without an optimizer step
            optimizer_closure()

    # ...
    # add as many optimizers as you want
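The pure-Python simulation below records the (batch_idx, optimizer_idx) pairs at which optimizer.step() runs under the alternating schedule above. The skipped calls correspond to running the closure alone. alternating_schedule is a hypothetical helper that mirrors only the branching logic.

```python
from typing import List, Tuple


def alternating_schedule(num_batches: int) -> List[Tuple[int, int]]:
    stepped = []
    for batch_idx in range(num_batches):
        for optimizer_idx in (0, 1):
            if optimizer_idx == 0 or (batch_idx + 1) % 2 == 0:
                stepped.append((batch_idx, optimizer_idx))  # optimizer.step(closure=...)
            # otherwise only optimizer_closure() runs: training_step + backward, no step
    return stepped
```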
Here we add a manual learning rate warm-up without an lr scheduler.
# learning rate warm-up
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    # update params
    optimizer.step(closure=optimizer_closure)

    # manually warm up the lr during the first 500 steps
    if self.trainer.global_step < 500:
        lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
        for pg in optimizer.param_groups:
            # assumes `learning_rate` was saved with self.save_hyperparameters()
            pg["lr"] = lr_scale * self.hparams.learning_rate
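The pure-Python function below isolates the warm-up rule above so you can check the values it produces. warmup_lr is a hypothetical helper, and base_lr stands in for self.hparams.learning_rate. Because the hook sets the learning rate after optimizer.step(), each value takes effect on the following optimizer step.

```python
def warmup_lr(global_step: int, base_lr: float, warmup_steps: int = 500) -> float:
    if global_step < warmup_steps:
        # linear ramp: (step + 1) / warmup_steps of the base lr, capped at 1.0
        return min(1.0, float(global_step + 1) / warmup_steps) * base_lr
    # after warm-up the hook stops touching the lr, which was last set to base_lr
    return base_lr
```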
Access your Own Optimizer
The provided optimizer is a LightningOptimizer object wrapping your own optimizer configured in your configure_optimizers().
You can access your own optimizer with optimizer.optimizer. However, if you use your own optimizer to perform a step, Lightning won't be able to support accelerators, precision, and profiling for you.
# function hook in LightningModule
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    optimizer.step(closure=optimizer_closure)


# `optimizer` is a `LightningOptimizer` wrapping the optimizer.
# To access it, do the following.
# However, it won't work on TPU, AMP, etc...
def optimizer_step(
    self,
    epoch,
    batch_idx,
    optimizer,
    optimizer_idx,
    optimizer_closure,
    on_tpu=False,
    using_native_amp=False,
    using_lbfgs=False,
):
    optimizer = optimizer.optimizer
    optimizer.step(closure=optimizer_closure)
Bring your own Custom Learning Rate Schedulers
Lightning allows using custom learning rate schedulers that aren't available natively in PyTorch. One good example is the timm schedulers. When using custom learning rate schedulers that rely on a different API from the native PyTorch ones, you should override lr_scheduler_step() with your desired logic.
If you are using native PyTorch schedulers, there is no need to override this hook, since Lightning handles them automatically by default.
from timm.scheduler import TanhLRScheduler


def configure_optimizers(self):
    optimizer = ...
    scheduler = TanhLRScheduler(optimizer, ...)
    return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]


def lr_scheduler_step(self, scheduler, optimizer_idx, metric):
    scheduler.step(epoch=self.current_epoch)  # timm's scheduler needs the epoch value
Configure Gradient Clipping
To configure custom gradient clipping, consider overriding the configure_gradient_clipping() method.
The gradient_clip_val and gradient_clip_algorithm attributes from the Trainer are passed in the respective arguments here, and Lightning handles gradient clipping for you. If you want to use different values of your choice and still let Lightning handle the gradient clipping, you can use the built-in clip_gradients() method and pass the arguments along with your optimizer.
Warning
Make sure not to override the clip_gradients() method. If you want to customize gradient clipping, override the configure_gradient_clipping() method instead.
For example, here we will apply gradient clipping only to the gradients associated with optimizer A.
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    if optimizer_idx == 0:
        # Lightning will handle the gradient clipping
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )
Here we configure gradient clipping differently for optimizer B.
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm):
    if optimizer_idx == 0:
        # Lightning will handle the gradient clipping
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val, gradient_clip_algorithm=gradient_clip_algorithm
        )
    elif optimizer_idx == 1:
        self.clip_gradients(
            optimizer, gradient_clip_val=gradient_clip_val * 2, gradient_clip_algorithm=gradient_clip_algorithm
        )
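For reference, gradient_clip_algorithm selects between clipping by norm ("norm") and clipping by value ("value"). The functions below are simplified pure-Python sketches of those two rules on a flat list of gradients. They are modeled on the usual definitions behind torch.nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_value_, but they are illustrations and not Lightning's implementation. For example, PyTorch adds a small epsilon to the norm, which is omitted here.

```python
import math
from typing import List


def clip_by_norm(grads: List[float], max_norm: float) -> List[float]:
    total_norm = math.sqrt(sum(g * g for g in grads))  # global L2 norm of all gradients
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm  # rescale so the global norm equals max_norm
    return [g * scale for g in grads]


def clip_by_value(grads: List[float], clip_value: float) -> List[float]:
    # clamp each gradient element into [-clip_value, clip_value]
    return [max(-clip_value, min(clip_value, g)) for g in grads]
```

Norm clipping preserves the direction of the whole gradient vector and only shrinks its length. Value clipping acts element by element and can change the direction.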
Total Stepping Batches
You can use the built-in trainer property estimated_stepping_batches to compute the total number of stepping batches for the complete training. The property takes the gradient accumulation factor and the distributed setting into account, so you don't have to derive it manually. This is helpful, for example, with the OneCycleLR scheduler, which requires a pre-computed total_steps during initialization.
import torch


def configure_optimizers(self):
    optimizer = ...
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    return [optimizer], [scheduler]
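To build intuition for the number this property reports, here is a back-of-the-envelope estimate. It assumes a map-style dataset split evenly across devices (as a DistributedSampler with drop_last=False does), a fixed accumulation factor, and an optimizer step on each epoch's final, possibly partial, accumulation window. estimate_stepping_batches is a hypothetical helper. Lightning's own logic may also account for other settings, such as max_steps.

```python
import math


def estimate_stepping_batches(
    dataset_size: int, batch_size: int, num_devices: int, accumulate_grad_batches: int, max_epochs: int
) -> int:
    samples_per_device = math.ceil(dataset_size / num_devices)  # per-replica share of the data
    batches_per_epoch = math.ceil(samples_per_device / batch_size)  # last batch may be partial
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    return steps_per_epoch * max_epochs
```

For example, 10,000 samples on 4 devices with a batch size of 32 gives 79 batches per device per epoch. Accumulating 2 batches yields 40 steps per epoch, or 120 steps over 3 epochs.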