GradientAccumulationScheduler

class lightning.pytorch.callbacks.GradientAccumulationScheduler(scheduling)[source]

Bases: lightning.pytorch.callbacks.callback.Callback

Change gradient accumulation factor according to scheduling.

Parameters

scheduling (Dict[int, int]) – scheduling in format {epoch: accumulation_factor}

Note

The argument scheduling is a dictionary: each key is an epoch and its value is the accumulation factor to use from that epoch on. Warning: epochs are zero-indexed, so to change the accumulation factor after 4 epochs, set Trainer(accumulate_grad_batches={4: factor}) or GradientAccumulationScheduler(scheduling={4: factor}). See the example below.

Raises
  • TypeError – If scheduling is an empty dict, or not all keys and values of scheduling are integers.

  • IndexError – If minimal_epoch is less than 0.
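
As a quick illustration of these constraints (a sketch only; the exact exception messages differ between versions), each of the scheduling dicts below should be rejected:

>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler
>>> for bad_scheduling in ({}, {"4": 2}, {4: 2.5}, {-1: 2}):
...     try:
...         GradientAccumulationScheduler(scheduling=bad_scheduling)
...     except (TypeError, IndexError) as err:
...         print(type(err).__name__)  # which of the documented errors was raised
TypeError
TypeError
TypeError
IndexError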

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import GradientAccumulationScheduler

# from the fifth epoch, gradients are accumulated every 2 batches; the key is 4
# rather than 5 because epochs are zero-indexed.
>>> accumulator = GradientAccumulationScheduler(scheduling={4: 2})
>>> trainer = Trainer(callbacks=[accumulator])
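
The schedule may also define several milestones; the factors below are illustrative values only:

# accumulate every 2 batches from epoch 4 and every 8 batches from epoch 9
# (both zero-indexed); a factor of 1 applies before the first milestone
>>> accumulator = GradientAccumulationScheduler(scheduling={0: 1, 4: 2, 9: 8})
>>> trainer = Trainer(callbacks=[accumulator])
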
on_train_epoch_start(trainer, *_)[source]

Called when the train epoch begins.

Return type

None
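
Conceptually, the factor in effect at any epoch is the value of the most recent milestone at or below that epoch, with a factor of 1 assumed before the first milestone. The helper below is a sketch of that lookup rule, not Lightning's implementation:

>>> def factor_for_epoch(scheduling, epoch):
...     # most recent milestone at or below the current (zero-indexed) epoch
...     reached = [k for k in scheduling if k <= epoch]
...     return scheduling[max(reached)] if reached else 1
>>> factor_for_epoch({4: 2}, 3)
1
>>> factor_for_epoch({4: 2}, 4)
2
>>> factor_for_epoch({0: 1, 4: 2, 9: 8}, 10)
8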

on_train_start(trainer, pl_module)[source]

Performs a configuration validation before training starts and raises errors for incompatible settings.

Return type

None
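
The specific incompatible settings are not enumerated on this page. As one hedged example, manual optimization is assumed here to be such a setting, since the callback adjusts how the automatic optimization loop accumulates gradients:

>>> import lightning.pytorch as pl
# Assumption (not stated on this page): a module that disables automatic
# optimization would be rejected by this validation when the callback is used.
>>> class ManualOptimizationModel(pl.LightningModule):
...     def __init__(self):
...         super().__init__()
...         self.automatic_optimization = False
>>> trainer = Trainer(callbacks=[GradientAccumulationScheduler(scheduling={4: 2})])
# trainer.fit(ManualOptimizationModel()) would then be expected to raise during this hook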
