Lightning + multi-GPU + IterableDataset uneven batches

Hello,

I’m training a model using an IterableDataset on multiple GPUs. If the number of batches is uneven across the workers (ranks), training hangs. After some research it looks like in vanilla PyTorch this can be solved with the Join context manager, but this isn’t supported in Lightning yet (the issue is discussed here but is still open).
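For context, this is roughly what the vanilla-PyTorch approach looks like (a minimal sketch, not through Lightning): Join lets ranks that run out of batches early keep participating in the collective calls of the ranks that are still training. `my_model`, `local_rank` and `loader` are placeholder names, and the sketch assumes the process group is already initialized.

```python
import torch
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel

# Sketch of the vanilla PyTorch workaround for uneven inputs:
# assumes init_process_group() has been called and my_model/local_rank/loader exist.
model = DistributedDataParallel(my_model, device_ids=[local_rank])
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Join shadows the collective ops for ranks that exhaust their data early,
# so the remaining ranks don't hang waiting for them.
with Join([model]):
    for batch in loader:  # each rank may see a different number of batches
        optimizer.zero_grad()
        loss = model(batch).sum()
        loss.backward()
        optimizer.step()
```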

I would think Lightning + multi-GPU + IterableDataset is quite a common setup for large datasets that need to be streamed, so I’m a bit surprised that I couldn’t find any workarounds or suggestions for dealing with this issue. Interested to hear how people are handling it, or whether I’m missing something.

I have the same issue. Did you find a solution?

Hi Dolores,

Unfortunately I didn’t find a proper workaround, so I just throw away a few batches at the end of the epoch/validation set to ensure all ranks yield exactly the same number of batches. It works, but it isn’t ideal: if the batch size or number of devices changes, different batches get discarded, which isn’t great when comparing metrics on the validation set. If you find a better solution I’d be keen to hear!
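In case it helps, the truncation workaround looks roughly like the sketch below. It assumes each rank can cheaply count its own items up front (e.g. from shard metadata); `base_dataset` and `num_items_on_this_rank` are placeholder names, not part of any library API.

```python
import itertools

import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset


class EvenlyTruncatedDataset(IterableDataset):
    """Truncate every rank's stream to the smallest per-rank item count.

    Sketch only: assumes each rank knows up front how many items it will
    yield (e.g. from shard metadata).
    """

    def __init__(self, base_dataset, num_items_on_this_rank):
        self.base_dataset = base_dataset
        n = torch.tensor(num_items_on_this_rank)
        if dist.is_available() and dist.is_initialized():
            # NCCL only reduces CUDA tensors; gloo works with CPU tensors.
            if dist.get_backend() == "nccl":
                n = n.cuda()
            dist.all_reduce(n, op=dist.ReduceOp.MIN)
        # Every rank now agrees on the same (minimum) number of items.
        self.num_items = int(n.item())

    def __iter__(self):
        # Stop after num_items so all ranks finish the epoch together.
        return itertools.islice(iter(self.base_dataset), self.num_items)
```

The drawback mentioned above still applies: which items get dropped depends on the world size and batch size, so validation metrics aren’t strictly comparable across different hardware setups.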