Efficient Gradient Accumulation

Gradient accumulation works the same way with Fabric as in PyTorch. You are in control of which model accumulates and at what frequency:

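for iteration, batch in enumerate(dataloader):

    # Accumulate gradient 8 batches at a time
    # (iteration + 1) so the first step happens after 8 batches, not after the first one
    is_accumulating = (iteration + 1) % 8 != 0

    output = model(batch)
    loss = ...

    # .backward() accumulates when .zero_grad() wasn't called
    fabric.backward(loss)

    if not is_accumulating:
        # Step the optimizer after accumulation phase is over
        optimizer.step()
        optimizer.zero_grad()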
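The arithmetic behind this loop can be checked without any framework. The following is a minimal sketch, assuming a toy one-parameter model y_hat = w * x trained with mean squared error. `grad_mse` and `train_accumulated` are hypothetical helpers, and the running `grad` variable stands in for a parameter's `.grad` buffer. Each micro-batch gradient is scaled by 1 / accumulate_every (a common convention, assumed here because the loss computation is elided), so that eight accumulated micro-batches of 2 samples produce the same update as one batch of 16 samples.

```python
def grad_mse(w, xs, ys):
    """Gradient w.r.t. w of the mean squared error for the model y_hat = w * x."""
    n = len(xs)
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n


def train_accumulated(w, batches, lr, accumulate_every):
    """Plain SGD with gradient accumulation (hypothetical helper, not a Fabric API)."""
    grad = 0.0  # stands in for param.grad
    steps = 0
    for iteration, (xs, ys) in enumerate(batches):
        is_accumulating = (iteration + 1) % accumulate_every != 0
        # "backward": add this micro-batch's gradient instead of overwriting it
        grad += grad_mse(w, xs, ys) / accumulate_every
        if not is_accumulating:
            w -= lr * grad  # optimizer.step()
            grad = 0.0      # optimizer.zero_grad()
            steps += 1
    return w, steps


xs = [float(i) for i in range(16)]
ys = [3.0 * x + 1.0 for x in xs]
# Split the 16 samples into 8 micro-batches of 2 samples each
batches = [(xs[i:i + 2], ys[i:i + 2]) for i in range(0, 16, 2)]

w0, lr = 0.0, 0.001
w_accumulated, num_steps = train_accumulated(w0, batches, lr, accumulate_every=8)
# One optimizer step on the full batch of 16 samples, for comparison
w_full_batch = w0 - lr * grad_mse(w0, xs, ys)
```

Both `w_accumulated` and `w_full_batch` end up at 0.48 (up to floating-point rounding) after a single optimizer step. The `(iteration + 1)` offset makes the first step happen after a full group of `accumulate_every` micro-batches rather than after the very first batch.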

However, in a distributed setting, for example when training across multiple GPUs or machines, doing it this way can slow down your training loop significantly: every .backward() call synchronizes the gradients across all processes, even in iterations where the optimizer does not step. To avoid this overhead, we should skip the synchronization in .backward() during the accumulation phase and synchronize the gradients only once the accumulation phase is over. This can be achieved by wrapping the forward and backward calls in the no_backward_sync() context manager:

  for iteration, batch in enumerate(dataloader):

      # Accumulate gradient 8 batches at a time
      is_accumulating = (iteration + 1) % 8 != 0

+     with fabric.no_backward_sync(model, enabled=is_accumulating):
          output = model(batch)
          loss = ...

          # .backward() accumulates when .zero_grad() wasn't called
          fabric.backward(loss)

      if not is_accumulating:
          # Step the optimizer after accumulation phase is over
          optimizer.step()
          optimizer.zero_grad()

Strategies that don’t support this context manager emit a warning, and for single-device strategies it is a no-op. Both the model’s .forward() and the fabric.backward() call need to run inside this context.
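To see where the savings come from, the sketch below counts gradient all-reduces with and without skipping synchronization. It is a framework-free toy, not Fabric's implementation: `ToyDDPModel`, `no_backward_sync`, and `count_all_reduces` are hypothetical names, loosely modeled on PyTorch DDP's `no_sync()` context manager and on the shape of `fabric.no_backward_sync(model, enabled=...)`. As with DDP, skipping a sync does not discard gradients: they keep accumulating locally, and the next synchronized backward reduces the accumulated sum.

```python
from contextlib import contextmanager, nullcontext


class ToyDDPModel:
    """Hypothetical stand-in for a DDP-wrapped model that only counts all-reduces."""

    def __init__(self):
        self.require_backward_grad_sync = True
        self.num_all_reduces = 0

    @contextmanager
    def no_sync(self):
        # Temporarily disable gradient synchronization, then restore the previous setting
        previous = self.require_backward_grad_sync
        self.require_backward_grad_sync = False
        try:
            yield
        finally:
            self.require_backward_grad_sync = previous

    def backward(self):
        # Local gradients always accumulate; the cross-process all-reduce
        # only happens when synchronization is enabled
        if self.require_backward_grad_sync:
            self.num_all_reduces += 1


def no_backward_sync(model, enabled):
    """Hypothetical helper: skip syncing when enabled, otherwise do nothing."""
    return model.no_sync() if enabled else nullcontext()


def count_all_reduces(num_iterations, accumulate_every, skip_sync):
    model = ToyDDPModel()
    for iteration in range(num_iterations):
        is_accumulating = (iteration + 1) % accumulate_every != 0
        with no_backward_sync(model, enabled=skip_sync and is_accumulating):
            model.backward()
    return model.num_all_reduces
```

With 16 iterations and an accumulation window of 8, `count_all_reduces(16, 8, skip_sync=False)` performs 16 all-reduces, while `count_all_reduces(16, 8, skip_sync=True)` performs only 2, one per optimizer step.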

© Copyright (c) 2018-2023, Lightning AI et al...
