Global_step increased at new epoch regardless of gradient accumulation

Hello, in my setting the step count matters most, and I cut off training with max_steps. I noticed that some of the checkpoint files' names drift away from multiples of val_check_interval, which I set to N * grad_accumulation so that it corresponds to N effective batches.

Then I found that at every new epoch the trainer increases global_step even if the current set of grad accumulation batches is not complete (drop_last=False). Is this the expected behavior? Is the optimizer being stepped at the end regardless of the accumulation setting, or is it just global_step being wrongly increased?


This is a limitation of epoch-based training. If the grad_accumulation steps don't evenly divide the number of batches in an epoch, we reach the end of the epoch without having completed a full accumulation cycle. We then have to make a decision: A) Do we drop the gradients collected so far and not step the optimizer? B) Or do we step the optimizer using the gradients computed so far?

Lightning takes the second approach (B) so that all the training data from that epoch is included in the updates before the epoch ends. There is also a third option: keep the gradients accumulated so far, run the epoch-end logic, and when entering the new epoch continue where accumulation left off. But this turns out to be quite nasty to implement in the training loop. Also, we would need to hold the graph in memory while running the epoch-end logic, or checkpoint it somehow, and extra steps would have to be taken so that this does not interfere with the validation that runs at epoch end (e.g. memory).
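To make the arithmetic concrete, here is a toy simulation in plain Python (not Lightning's actual training loop) of how an optimizer-step counter evolves under option A ("drop"), option B ("flush" at epoch end, the approach described above for Lightning), and the third "carry over" option. The function name `global_steps` and the policy strings are hypothetical, and the sketch assumes global_step counts optimizer steps.

```python
from typing import List


def global_steps(batches_per_epoch: int, accumulate: int, epochs: int, policy: str) -> List[int]:
    """Return the optimizer-step count at the end of each epoch (toy model).

    policy:
      "drop"  - option A: discard a partial accumulation group at epoch end
      "flush" - option B: step the optimizer on a partial group at epoch end
      "carry" - third option: keep the partial group and finish it next epoch
    """
    if policy not in ("drop", "flush", "carry"):
        raise ValueError(f"unknown policy: {policy}")
    step = 0
    pending = 0  # batches whose gradients are accumulated but not yet applied
    history = []
    for _ in range(epochs):
        for _ in range(batches_per_epoch):
            pending += 1
            if pending == accumulate:  # a full accumulation group is complete
                step += 1
                pending = 0
        if pending:  # epoch ended in the middle of an accumulation group
            if policy == "flush":
                step += 1  # apply the partial gradients anyway
                pending = 0
            elif policy == "drop":
                pending = 0  # throw the partial gradients away
            # "carry": leave `pending` as-is so the next epoch completes it
        history.append(step)
    return history


print(global_steps(10, 4, 3, "drop"))   # [2, 4, 6]
print(global_steps(10, 4, 3, "flush"))  # [3, 6, 9]
print(global_steps(10, 4, 3, "carry"))  # [2, 5, 7]
```

With 10 batches per epoch and an accumulation factor of 4, option B takes ceil(10 / 4) = 3 steps per epoch, so after 3 epochs the counter reads 9 even though the 30 batches seen would only form 7 complete groups of 4. That extra step per epoch is consistent with the checkpoint-name drift described in the question.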

Off the top of my head, the same principle also extends to the boundaries around val_check_interval: if validation runs more than once per epoch, the same reasoning applies at those boundaries.

I know this doesn't solve your problem, but I hope the explanation provides a bit more context. Maybe you could try truncating your dataset slightly so that the accumulation factor evenly divides the number of batches per epoch. What do you think?
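A minimal sketch of that truncation, assuming a single process (no distributed sharding). The first helper works on an already-built list of batches; the second assumes a fixed batch size with drop_last=False, so the batch count is ceil(num_samples / batch_size). Both helper names are hypothetical.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def truncate_to_multiple(batches: Sequence[T], accumulate: int) -> List[T]:
    """Keep only as many batches as form complete accumulation groups."""
    if accumulate < 1:
        raise ValueError("accumulate must be >= 1")
    keep = (len(batches) // accumulate) * accumulate
    return list(batches[:keep])


def samples_to_keep(num_samples: int, batch_size: int, accumulate: int) -> int:
    """Largest sample count whose batches (fixed size) divide evenly by `accumulate`."""
    group = batch_size * accumulate  # samples consumed by one optimizer step
    return (num_samples // group) * group


print(len(truncate_to_multiple(list(range(10)), 4)))  # 8
print(samples_to_keep(1000, 32, 4))                   # 896, i.e. 28 batches of 32
```

Truncating at sample level only works when the batch size is fixed; with token-based batching the number of batches is only known after batching, so truncating the batch list is the simpler option there.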

I understand the reasoning, and I agree that keeping the grads for the next epoch would be very difficult to implement.

For my task specifically (similar to NLP), the batch size actually depends on max_tokens, and we created a sampler that batches samples of similar length together (not strictly sorted, of course). The side effect is that we always step through shorter sentences first, so leaving out the last batch is unfair to long sentences. What I can think of is to first randomly choose the samples to be left out and put them at the end (to be fair).
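One possible reading of that idea, sketched at batch granularity rather than sample granularity (an assumption, since a max_tokens sampler only knows the batch count after batching): pick the `len(batches) % accumulate` leftover batches uniformly at random instead of always leaving the last (longest) ones, and move them to the end. The function name is hypothetical.

```python
import random
from typing import List, Optional, Sequence, TypeVar

T = TypeVar("T")


def move_random_remainder_to_end(
    batches: Sequence[T], accumulate: int, rng: Optional[random.Random] = None
) -> List[T]:
    """Reorder batches so the trailing partial accumulation group is chosen at random.

    The non-leftover batches keep their original relative order; the randomly
    chosen leftover batches are appended at the end (in their original order).
    """
    rng = rng or random.Random()
    remainder = len(batches) % accumulate
    if remainder == 0:
        return list(batches)  # already divisible, nothing to move
    leftover = set(rng.sample(range(len(batches)), remainder))
    kept = [b for i, b in enumerate(batches) if i not in leftover]
    tail = [b for i, b in enumerate(batches) if i in leftover]
    return kept + tail


reordered = move_random_remainder_to_end(list(range(10)), 4, random.Random(0))
print(reordered[:8], reordered[8:])  # 8 batches in full groups, 2 random leftovers
```

The tail can then either be dropped (for example by truncating to a multiple of the accumulation factor) or left for the end-of-epoch optimizer step; since the sampler is assumed to re-randomize each epoch, different batches end up in the tail over time.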

Thanks for the immediate reply!
