Yeah, my bad, it should be learning_rate*N. Fixed it.
Also, for batch_size, you should set it to the value you want on a single device, because PL uses this batch_size on each device, so the effective batch size will be batch_size*N. Not sure about learning_rate yet.
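A minimal sketch of the arithmetic above, in plain Python with no Lightning calls. It assumes a DDP-style setup where every process builds its own DataLoader with the same batch_size, so N is the total number of processes (devices times nodes). The helper names `effective_batch_size` and `scaled_learning_rate` are hypothetical. The learning-rate helper only applies the learning_rate*N linear-scaling heuristic mentioned earlier; it does not assume Lightning adjusts the learning rate for you.

```python
def effective_batch_size(per_device_batch_size: int, num_devices: int, num_nodes: int = 1) -> int:
    """Samples per optimizer step, assuming each process loads its own full batch (DDP-style)."""
    return per_device_batch_size * num_devices * num_nodes


def scaled_learning_rate(base_lr: float, num_devices: int, num_nodes: int = 1) -> float:
    """Linear-scaling heuristic (hypothetical helper): single-device LR times the number of processes."""
    return base_lr * num_devices * num_nodes


if __name__ == "__main__":
    batch_size = 32  # the value you would use on a single device
    n_gpus = 4       # hypothetical number of devices
    print(effective_batch_size(batch_size, n_gpus))  # 128
    print(scaled_learning_rate(1e-3, n_gpus))
```

So if 32 is what fits on one GPU, you keep batch_size=32 and each optimizer step on 4 GPUs sees 128 samples.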
What do you mean by this statement?