Hello,
I’m training a model using an IterableDataset on multiple GPUs. If the number of batches is uneven across workers, training hangs. After some research it looks like in vanilla PyTorch one can use the Join context manager to solve this, but it isn’t supported yet in Lightning (the issue is discussed here but is still open).
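For reference, this is roughly the vanilla PyTorch pattern I mean, as a minimal sketch (the dataset, model, and batch counts below are just placeholders, not my actual setup):

```python
import os
import torch
import torch.distributed as dist
from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import IterableDataset, DataLoader


class UnevenStream(IterableDataset):
    """Toy stream that deliberately yields a different number of samples per rank."""
    def __iter__(self):
        for _ in range(100 + dist.get_rank() * 10):
            yield torch.randn(16)


def main():
    dist.init_process_group("nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)

    model = DDP(torch.nn.Linear(16, 1).cuda(rank), device_ids=[rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    loader = DataLoader(UnevenStream(), batch_size=8)

    # Join lets ranks that exhaust their iterator early "shadow" the remaining
    # collective ops, so the ranks that still have batches don't block forever.
    with Join([model]):
        for batch in loader:
            optimizer.zero_grad()
            loss = model(batch.cuda(rank)).sum()
            loss.backward()
            optimizer.step()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()  # e.g. torchrun --nproc_per_node=<num_gpus> this_script.py
```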
I would think Lightning + multi-GPU + IterableDataset is quite a common setup for large datasets that need to be streamed, so I’m a bit surprised that I couldn’t find any workarounds or suggestions for dealing with this issue. Interested to hear how people are handling it, or whether I’m missing something.