Let’s say I trained a network with batch size 64, a Trainer with accumulate_grad_batches = 1 and learning rate 0.01. Now I want to train with batch size 32 and accumulate_grad_batches = 2. Should I scale the learning rate by 0.5, since the gradient is now the sum of the gradients of two batches? Thanks!
Hey
If you sum the loss over the batch, the two are equivalent. But in practice the loss is often the mean over the batch, in which case the gradients need to be scaled for the two to be equivalent. Or, as you said, the same can be achieved by scaling the learning rate.
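To make the factor concrete (a quick sketch of the arithmetic, assuming a mean-reduced loss and a plain SGD update, as in the snippet below): a full batch of size $B$ gives the per-sample average of the gradients, while $k$ accumulation steps over micro-batches of size $B/k$ each contribute a mean over only $B/k$ samples, so the accumulated gradient is $k$ times larger:

$$
\underbrace{\frac{1}{B}\sum_{i=1}^{B} g_i}_{\text{batched}}
\qquad\text{vs.}\qquad
\underbrace{\sum_{j=1}^{k} \frac{k}{B}\sum_{i \in \text{micro-batch } j} g_i}_{\text{accumulated}}
= k \cdot \frac{1}{B}\sum_{i=1}^{B} g_i .
$$

So for your numbers ($k = 2$), dividing the learning rate by 2 (0.01 → 0.005) gives the same parameter update.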
I made a quick PyTorch example to demonstrate this:
import torch
import torch.nn as nn
torch.manual_seed(1)
batch = torch.rand(2, 4)
label = torch.rand(2, 2)
def batched(reduction, lr=0.1):
    torch.manual_seed(2)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    # one backward pass over the full batch of 2 samples
    reduction((model(batch) - label).abs()).backward()
    gradients_0 = torch.cat([p.grad.clone().view(-1) for p in model.parameters()])
    optimizer.step()
    params_0 = torch.cat([p.data.clone().view(-1) for p in model.parameters()])
    return params_0, gradients_0
def accumulation(reduction, lr=0.1):
    torch.manual_seed(2)
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)  # try 0.05
    # two backward passes over micro-batches of 1 sample each;
    # the gradients accumulate in p.grad until optimizer.step()
    loss_0 = (model(batch[0].view(1, -1)) - label[0].view(1, -1)).abs()
    reduction(loss_0).backward()
    loss_1 = (model(batch[1].view(1, -1)) - label[1].view(1, -1)).abs()
    reduction(loss_1).backward()
    gradients_1 = torch.cat([p.grad.clone().view(-1) for p in model.parameters()])
    optimizer.step()
    params_1 = torch.cat([p.data.clone().view(-1) for p in model.parameters()])
    return params_1, gradients_1
# Gradient accumulation is equivalent to batched optimization if the loss is summed over the batch
p0, g0 = batched(reduction=torch.sum)
p1, g1 = accumulation(reduction=torch.sum)
print("gradients equal", torch.allclose(g0, g1))
print("gradients equal", torch.allclose(p0, p1))
# If the loss is the mean over the batch, you need to scale the gradients
# (or change the learning rate, or scale the loss)
p0, g0 = batched(reduction=torch.mean, lr=1.0)
p1, g1 = accumulation(reduction=torch.mean, lr=0.5)
print("gradients equal", torch.allclose(g0, g1))
print("gradients equal", torch.allclose(p0, p1))