GradientAccumulationScheduler
class pytorch_lightning.callbacks.GradientAccumulationScheduler(scheduling)
Bases: pytorch_lightning.callbacks.base.Callback

Change the gradient accumulation factor according to scheduling.

Note

The argument scheduling is a dictionary. Each key represents an epoch, and its value is the accumulation factor that takes effect from that epoch until the next scheduled epoch.

Warning: Epochs are zero-indexed. For example, to change the accumulation factor after 4 epochs (that is, starting with the fifth epoch), set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). For more info, see the example below.

Raises
- TypeError – If scheduling is an empty dict, or if not all keys and values of scheduling are integers.
- IndexError – If minimal_epoch (the smallest epoch key in scheduling) is less than 0.
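The validation rules listed under Raises can be tried out in plain Python. The sketch below uses a hypothetical helper, validate_scheduling, written only for illustration; it is not Lightning's implementation and simply reproduces the TypeError and IndexError conditions documented above, so it runs without Lightning installed.

```python
from typing import Dict


def validate_scheduling(scheduling: Dict[int, int]) -> Dict[int, int]:
    """Hypothetical helper mirroring the documented Raises conditions.

    Not Lightning's code: it only re-checks the rules listed under Raises.
    """
    # An empty dict is rejected with TypeError.
    if not scheduling:
        raise TypeError("scheduling must not be empty")
    # Every epoch key and every accumulation factor must be an int.
    for epoch, factor in scheduling.items():
        if not isinstance(epoch, int) or not isinstance(factor, int):
            raise TypeError("all epochs and accumulation factors must be integers")
    # minimal_epoch is the smallest epoch key; a negative epoch is rejected.
    minimal_epoch = min(scheduling)
    if minimal_epoch < 0:
        raise IndexError(f"epoch {minimal_epoch} cannot be negative")
    return scheduling


validate_scheduling({4: 2})  # passes and returns the dict unchanged
```

Calling it with {} or {4: 2.0} raises TypeError, and calling it with {-1: 2} raises IndexError.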
 
Example:

    >>> from pytorch_lightning import Trainer
    >>> from pytorch_lightning.callbacks import GradientAccumulationScheduler

    # from epoch 5, it starts accumulating every 2 batches. Here we use 4
    # instead of 5 because the epoch keys are zero-indexed.
    >>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
    >>> trainer = Trainer(callbacks=[accumulator])

    # alternatively, pass the scheduling dict directly to the Trainer
    >>> trainer = Trainer(accumulate_grad_batches={4: 2})
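To make the zero-indexed schedule concrete, the following is a minimal sketch of how a scheduling dict maps to the factor in effect at each epoch. The helper accumulation_factor is hypothetical and not part of Lightning; it assumes each key is the first epoch at which its factor applies, that the factor stays in effect until the next key, and that epochs before the smallest key use a factor of 1 (no accumulation).

```python
from typing import Dict


def accumulation_factor(scheduling: Dict[int, int], epoch: int) -> int:
    """Return the accumulation factor in effect at a zero-indexed epoch.

    Hypothetical helper: epochs before the smallest key fall back to 1,
    an assumption of this sketch.
    """
    factor = 1
    # Walk the scheduled start epochs in order and keep the last one reached.
    for start_epoch in sorted(scheduling):
        if start_epoch <= epoch:
            factor = scheduling[start_epoch]
        else:
            break
    return factor


schedule = {4: 2}
print([accumulation_factor(schedule, e) for e in range(6)])
# → [1, 1, 1, 1, 2, 2]
```

With {4: 2}, epochs 0 to 3 (the first four) use a factor of 1, and accumulation over 2 batches begins at epoch 4, the fifth epoch, which matches the comment in the example.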