GradientAccumulationScheduler
class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)
Bases: lightning.pytorch.callbacks.callback.Callback
Changes the gradient accumulation factor during training according to scheduling.
Note
The argument scheduling is a dictionary that maps each epoch (key) to the accumulation factor (value) used from that epoch until the next key in the dictionary. Warning: epochs are zero-indexed, so to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more information, see the example below.
Raises
TypeError – If scheduling is an empty dict, or if any key or value of scheduling is not an integer.
IndexError – If minimal_epoch (the smallest epoch key in scheduling) is less than 0.
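The rules listed under Raises can be sketched as a standalone check. `check_scheduling` is a hypothetical helper written only for illustration, not part of Lightning; it follows the rules exactly as documented here, and the library's own validation may differ in detail.

```python
def _is_int(x: object) -> bool:
    # bool is a subclass of int in Python; exclude it so True/False are rejected
    return isinstance(x, int) and not isinstance(x, bool)


def check_scheduling(scheduling: dict) -> None:
    """Hypothetical helper: validate an {epoch: factor} dict per the documented Raises rules."""
    # An empty dict cannot describe any schedule
    if not scheduling:
        raise TypeError("scheduling must not be an empty dict")
    # Every key (epoch) and every value (accumulation factor) must be an integer
    if not all(_is_int(k) and _is_int(v) for k, v in scheduling.items()):
        raise TypeError("all keys and values of scheduling must be integers")
    # Epochs are zero-indexed, so the smallest epoch key may not be negative
    minimal_epoch = min(scheduling)
    if minimal_epoch < 0:
        raise IndexError(f"epoch {minimal_epoch} is negative; epochs are zero-indexed")


check_scheduling({4: 2})  # valid: passes silently
```

With these checks, {4: 2} passes, while {} or {1.5: 2} raise TypeError and {-1: 2} raises IndexError.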
Example:
>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# From the fifth epoch (index 4) onward, accumulate gradients over every 2 batches.
# The key is 4 rather than 5 because epoch keys are zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])
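To make the zero-indexed keys in the example concrete, the epoch-to-factor lookup can be sketched in plain Python. `factor_for_epoch` is a hypothetical helper, not part of Lightning's API. It assumes that epochs before the smallest key use a factor of 1 (no accumulation) and that each factor stays in effect until the next key is reached.

```python
from typing import Dict


def factor_for_epoch(scheduling: Dict[int, int], epoch: int) -> int:
    """Hypothetical helper: accumulation factor in effect at a zero-indexed epoch."""
    factor = 1  # assumption: no accumulation before the first scheduled epoch
    for start in sorted(scheduling):
        if epoch >= start:
            factor = scheduling[start]  # the latest key already reached wins
        else:
            break  # keys are sorted, so no later key can apply
    return factor


schedule = {4: 2}  # same schedule as the example above
print([factor_for_epoch(schedule, e) for e in range(7)])  # [1, 1, 1, 1, 2, 2, 2]
```

Under the same assumptions, a multi-stage schedule such as {0: 8, 4: 4, 8: 1} would accumulate over 8 batches in epochs 0 to 3, over 4 batches in epochs 4 to 7, and stop accumulating from epoch 8 onward.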