class pytorch_lightning.callbacks.GradientAccumulationScheduler(scheduling)

Bases: pytorch_lightning.callbacks.callback.Callback

Changes the gradient accumulation factor according to scheduling.


Parameters:

scheduling (Dict[int, int]) – schedule in the format {epoch: accumulation_factor}


The argument scheduling is a dictionary in which each key is an epoch and each value is the accumulation factor to use from that epoch on. Warning: epochs are zero-indexed, so to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). See the example below for details.
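To make the zero-indexed lookup concrete, here is a minimal plain-Python sketch of how a {epoch: accumulation_factor} dict can be resolved to the factor in effect at a given epoch. The helper accumulation_factor_for_epoch is hypothetical, not part of Lightning, and it assumes (consistent with the example below) that epochs before the first key use a factor of 1 and that each factor stays in effect until the next scheduled epoch.

```python
from typing import Dict


def accumulation_factor_for_epoch(scheduling: Dict[int, int], epoch: int) -> int:
    """Hypothetical helper: return the factor in effect at a zero-indexed epoch.

    The factor comes from the largest scheduled epoch key that is <= `epoch`.
    """
    factor = 1  # assumed default: no accumulation before the first scheduled epoch
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
        else:
            break  # keys are sorted, so no later key can apply
    return factor


# {4: 2}: epochs 0-3 keep factor 1; from epoch index 4 (the fifth epoch) on, factor 2.
print([accumulation_factor_for_epoch({4: 2}, e) for e in range(6)])
# → [1, 1, 1, 1, 2, 2]
```

With several keys, for example {0: 8, 4: 4, 8: 1}, the same rule gives 8 for epochs 0-3, 4 for epochs 4-7, and 1 from epoch 8 on.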

Raises:

  • TypeError – If scheduling is an empty dict, or if not all of its keys and values are integers.

  • IndexError – If minimal_epoch, the smallest epoch key in scheduling, is less than 0.
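The documented error cases can be illustrated with a small validation function. This is a sketch of equivalent checks, not Lightning's own validation code: validate_scheduling is a hypothetical name, the error messages are made up, and minimal_epoch is assumed to be min(scheduling).

```python
from typing import Dict


def validate_scheduling(scheduling: Dict[int, int]) -> None:
    """Hypothetical checks mirroring the documented TypeError/IndexError cases."""
    # An empty (or non-dict) schedule cannot be interpreted.
    if not isinstance(scheduling, dict) or not scheduling:
        raise TypeError("scheduling must be a non-empty dict")
    # Every epoch key and every accumulation factor must be an int.
    if not all(isinstance(k, int) and isinstance(v, int) for k, v in scheduling.items()):
        raise TypeError("all keys and values of scheduling must be integers")
    # Epochs are zero-indexed, so the smallest key must not be negative.
    minimal_epoch = min(scheduling)
    if minimal_epoch < 0:
        raise IndexError(f"epochs are zero-indexed; got epoch {minimal_epoch}")


for bad in ({}, {"4": 2}, {-1: 2}):
    try:
        validate_scheduling(bad)
    except (TypeError, IndexError) as exc:
        print(type(exc).__name__, "-", exc)
```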

Example:

>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# from epoch 5 (index 4) onward, accumulate gradients over every 2 batches;
# the key is 4 rather than 5 because epochs are zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={4: 2})
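The callback applies the schedule through its on_train_epoch_start hook. The following framework-free sketch shows that pattern, assuming the hook reads trainer.current_epoch and writes the resulting factor to trainer.accumulate_grad_batches. ToyAccumulationScheduler is a hypothetical stand-in rather than the Lightning class, and a SimpleNamespace plays the role of the Trainer with only those two attributes.

```python
from bisect import bisect_right
from types import SimpleNamespace
from typing import Dict


class ToyAccumulationScheduler:
    """Hypothetical stand-in: sets the trainer's accumulation factor each epoch."""

    def __init__(self, scheduling: Dict[int, int]) -> None:
        self.scheduling = dict(scheduling)
        # Assumed: factor 1 applies until the first scheduled epoch.
        self.scheduling.setdefault(0, 1)
        self.epochs = sorted(self.scheduling)

    def on_train_epoch_start(self, trainer, *_) -> None:
        # Index of the last scheduled epoch that is <= the current epoch.
        idx = bisect_right(self.epochs, trainer.current_epoch) - 1
        trainer.accumulate_grad_batches = self.scheduling[self.epochs[idx]]


# Toy driver that calls the hook at the start of each of 6 epochs.
trainer = SimpleNamespace(current_epoch=0, accumulate_grad_batches=1)
scheduler = ToyAccumulationScheduler({4: 2})
seen = []
for epoch in range(6):
    trainer.current_epoch = epoch
    scheduler.on_train_epoch_start(trainer)
    seen.append(trainer.accumulate_grad_batches)
print(seen)  # [1, 1, 1, 1, 2, 2]
```

Using bisect_right on the sorted keys keeps each lookup logarithmic in the number of scheduled epochs, though for typical schedules with a handful of keys a linear scan would work just as well.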
on_train_epoch_start(trainer, *_)

Called when the train epoch begins.

Return type: