Effective learning rate and batch size with Lightning in DDP

  1. Consider the MSE loss as an example. It is typically computed as the mean of the per-sample MSE, namely
    1/N sum_i(MSE_i)
    In DDP each worker computes this mean over its local shard only, so the gradients all-reduced across workers must be averaged (divided by the world size) rather than summed to mimic the single-process mean over all N samples.
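    A minimal sketch of this point, in pure Python with no torch/DDP (the names `world_size`, `mse_grad`, and the toy data are illustrative): for a 1-D linear model y_hat = w*x, the per-worker mean gradients, averaged across workers, reproduce the global-mean gradient 1/N sum_i MSE_i when the local shards are equal-sized.

    ```python
    def mse_grad(w, xs, ys):
        """Gradient of (1/N) * sum_i (w*x_i - y_i)^2 with respect to w."""
        n = len(xs)
        return (2.0 / n) * sum((w * x - y) * x for x, y in zip(xs, ys))

    w = 0.5
    xs = [1.0, 2.0, 3.0, 4.0]
    ys = [2.0, 4.0, 6.0, 8.0]

    # Single process: mean over the full batch of N samples.
    global_grad = mse_grad(w, xs, ys)

    # Two "workers", each averaging over its local shard of N/2 samples.
    world_size = 2
    local_grads = [mse_grad(w, xs[i::world_size], ys[i::world_size])
                   for i in range(world_size)]

    # Averaging (not summing) across workers recovers the global mean.
    ddp_grad = sum(local_grads) / world_size

    print(global_grad, ddp_grad)  # -> -22.5 -22.5
    ```

    If the gradients were summed across workers instead, the result would be off by a factor of the world size, which is why the reduction mode matters.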

  2. As far as I know, the learning rate is scaled with the batch size so that the variance of the parameter update is kept approximately constant. Intuitively, the larger the batch, the more we expect the gradient averaged over the samples to point in the correct direction: averaging over B samples reduces the gradient variance by a factor of 1/B. Since Var(aX) = a^2 Var(X) for a constant a != 0, scaling the learning rate by sqrt(effective_batch_size/baseline_batch_size) keeps lr^2 * Var(grad) approximately constant.
    See the example I gave here
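    A hedged sketch of this square-root scaling rule (the function `scale_lr` and its parameter names are illustrative, not a Lightning API):

    ```python
    import math

    def scale_lr(base_lr, base_batch, per_device_batch, world_size, accumulate=1):
        """Scale the lr by sqrt(effective/baseline batch size).

        The variance of the mean gradient over B samples scales like 1/B,
        so growing the lr like sqrt(B) keeps lr^2 * Var(grad) roughly constant.
        """
        effective_batch = per_device_batch * world_size * accumulate
        return base_lr * math.sqrt(effective_batch / base_batch)

    # Baseline: lr=1e-3 at batch size 256. With 8 GPUs x 128 samples per
    # device, the effective batch is 1024, so the lr doubles:
    print(scale_lr(1e-3, 256, 128, 8))  # -> 0.002
    ```

    The effective batch size here is per-device batch x world size x gradient-accumulation steps, which is the quantity DDP training actually averages over per optimizer step.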