GradientAccumulationScheduler
- class pytorch_lightning.callbacks.GradientAccumulationScheduler(scheduling)
Bases: pytorch_lightning.callbacks.base.Callback
Changes the gradient accumulation factor during training, switching to a new factor at each epoch listed in scheduling.
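Each key in scheduling marks the epoch from which its factor applies, so the factor in effect at a given epoch comes from the latest scheduled epoch that has already started. The following is a minimal sketch of that lookup. It assumes that epochs before the smallest key use a factor of 1 (no accumulation). The helper `resolve_accumulation_factor` is hypothetical and not part of the Lightning API:

```python
from typing import Dict


def resolve_accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    """Return the accumulation factor in effect at `epoch` (hypothetical helper).

    The factor comes from the largest scheduled epoch that is <= `epoch`.
    Before the first scheduled epoch, no accumulation happens (factor 1).
    """
    factor = 1  # assumed default before any scheduled epoch
    for start_epoch in sorted(scheduling):
        if start_epoch <= epoch:
            factor = scheduling[start_epoch]
        else:
            break  # keys are sorted, so every remaining key starts later
    return factor


schedule = {5: 2, 10: 4}
print([resolve_accumulation_factor(schedule, e) for e in range(12)])
# [1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 4, 4]
```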
- Parameters
scheduling (Dict[int, int]) – mapping from the epoch at which a factor takes effect to the accumulation factor, in the format {epoch: accumulation_factor}
- Raises
TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.
IndexError – If minimal_epoch (the smallest epoch key in scheduling) is less than 0.
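The conditions listed under Raises can be expressed as a small standalone function. This is a sketch that mirrors the documented checks and is not Lightning's own code. The function `validate_scheduling` and its error messages are hypothetical:

```python
from typing import Any, Dict


def validate_scheduling(scheduling: Dict[Any, Any]) -> None:
    """Mirror the documented validation rules (hypothetical standalone helper)."""
    # An empty dict has nothing to schedule.
    if not scheduling:
        raise TypeError("scheduling must be a non-empty dict")
    # Every key (epoch) and value (accumulation factor) must be an int.
    if not all(isinstance(k, int) and isinstance(v, int) for k, v in scheduling.items()):
        raise TypeError("all keys and values of scheduling must be integers")
    # The smallest scheduled epoch (minimal_epoch) must not be negative.
    minimal_epoch = min(scheduling)
    if minimal_epoch < 0:
        raise IndexError(f"epoch {minimal_epoch} cannot be negative")


validate_scheduling({5: 2})  # valid: returns None without raising
for bad in ({}, {5: "2"}, {-1: 2}):
    try:
        validate_scheduling(bad)
    except (TypeError, IndexError) as err:
        print(type(err).__name__, "-", err)
```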
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

# at epoch 5 start accumulating every 2 batches
>>> accumulator = GradientAccumulationScheduler(scheduling={5: 2})
>>> trainer = Trainer(callbacks=[accumulator])

# alternatively, pass the scheduling dict directly to the Trainer
>>> trainer = Trainer(accumulate_grad_batches={5: 2})
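To make the effect of {5: 2} concrete, the sketch below counts how many optimizer steps each epoch would take. It assumes a hypothetical 10 batches per epoch. It also assumes that a step is taken after every `factor` batches, plus one on the final batch when the epoch length is not a multiple of `factor`. The function `optimizer_steps_per_epoch` is illustrative and not a Lightning function:

```python
from typing import Dict, List


def optimizer_steps_per_epoch(
    scheduling: Dict[int, int], num_batches: int, num_epochs: int
) -> List[int]:
    """Count optimizer steps per epoch under a {epoch: factor} schedule (hypothetical)."""
    steps = []
    for epoch in range(num_epochs):
        # Factor from the latest scheduled epoch <= current epoch; 1 before any key.
        started = [e for e in scheduling if e <= epoch]
        factor = scheduling[max(started)] if started else 1
        count = 0
        for batch_idx in range(num_batches):
            is_last = batch_idx + 1 == num_batches
            # Gradients from `factor` consecutive batches are accumulated before one step.
            # A leftover partial group at the end of the epoch still gets a step (assumed).
            if (batch_idx + 1) % factor == 0 or is_last:
                count += 1
        steps.append(count)
    return steps


print(optimizer_steps_per_epoch({5: 2}, num_batches=10, num_epochs=7))
# [10, 10, 10, 10, 10, 5, 5]
```

From epoch 5 onward, each optimizer step sees gradients from two batches, so the number of parameter updates per epoch halves while the number of forward and backward passes stays the same.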