GradientAccumulationScheduler

class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)[source]

Bases: Callback

Change gradient accumulation factor according to scheduling.

Parameters:

scheduling (dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Note

The scheduling argument is a dictionary. Each key is an epoch, and its value is the accumulation factor that applies from that epoch onward. Warning: epochs are zero-indexed, so to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more information, see the example below.
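The zero-indexed lookup described in this note can be sketched in plain Python. This is a minimal illustration, not the callback's actual implementation: the helper name factor_for_epoch is hypothetical, and the fallback factor of 1 for epochs before the first scheduled key is an assumption.

```python
def factor_for_epoch(scheduling: dict[int, int], epoch: int) -> int:
    """Return the accumulation factor in effect at a zero-indexed epoch."""
    # Assumption: epochs before the first scheduled key use a factor of 1.
    factor = 1
    # Walk the keys in ascending order; the last key <= epoch wins.
    for start_epoch in sorted(scheduling):
        if epoch >= start_epoch:
            factor = scheduling[start_epoch]
    return factor


schedule = {4: 2}
# Epochs 0-3 (the first four epochs) keep factor 1; epoch 4 (the fifth) switches to 2.
factors = [factor_for_epoch(schedule, epoch) for epoch in range(6)]
print(factors)  # [1, 1, 1, 1, 2, 2]
```

The key 4 therefore takes effect at the fifth epoch, which is the off-by-one that the warning above is pointing at.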

Raises:
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • IndexError – If minimal_epoch (the smallest epoch key in scheduling) is less than 0.
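The documented error conditions can be mirrored in a small standalone check. This is a sketch of the contract listed above, not the library's code: validate_scheduling and error_name are hypothetical helpers, and the error messages are invented for illustration.

```python
def validate_scheduling(scheduling: dict) -> None:
    """Raise the errors documented above for an invalid scheduling dict."""
    if not scheduling:
        raise TypeError("scheduling must be a non-empty dict")
    if not all(isinstance(k, int) and isinstance(v, int) for k, v in scheduling.items()):
        raise TypeError("all keys and values of scheduling must be integers")
    # minimal_epoch is the smallest epoch key in the schedule.
    minimal_epoch = min(scheduling)
    if minimal_epoch < 0:
        raise IndexError("epoch keys must be zero or greater")


def error_name(scheduling: dict):
    """Return the name of the raised error, or None if the dict is valid."""
    try:
        validate_scheduling(scheduling)
    except (TypeError, IndexError) as err:
        return type(err).__name__
    return None


print(error_name({}))        # TypeError
print(error_name({"4": 2}))  # TypeError
print(error_name({-1: 2}))   # IndexError
print(error_name({4: 2}))    # None
```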

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# Starting from the 5th epoch, gradients are accumulated over every 2 batches.
# The key is 4 rather than 5 because epoch keys are zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])
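
A schedule may contain more than one entry. The plain-Python sketch below shows how a multi-stage schedule would change the effective batch size per optimizer step, assuming a single training process. The schedule values, the batch_size of 32, and the helper effective_batch_size are all hypothetical and chosen only for illustration.

```python
batch_size = 32                # hypothetical DataLoader batch size
schedule = {0: 8, 4: 4, 8: 1}  # hypothetical {epoch: accumulation_factor}


def effective_batch_size(epoch: int) -> int:
    """Samples contributing to one optimizer step at a zero-indexed epoch."""
    # The schedule defines key 0, so there is always a key <= epoch.
    start_epoch = max(key for key in schedule if key <= epoch)
    return batch_size * schedule[start_epoch]


sizes = {epoch: effective_batch_size(epoch) for epoch in (0, 3, 4, 7, 8, 10)}
print(sizes)  # {0: 256, 3: 256, 4: 128, 7: 128, 8: 32, 10: 32}
```

The same dictionary could be passed as GradientAccumulationScheduler(scheduling={0: 8, 4: 4, 8: 1}) in place of the single-entry schedule used in the example above.
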
on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type:

None

on_train_start(trainer, pl_module)[source]

Performs a configuration validation before training starts and raises errors for incompatible settings.

Return type:

None