My problem is pretty similar to this one, but I haven’t found an answer here or elsewhere.
I’m using PyTorch Lightning together with WebDataset, a third-party IterableDataset. To ensure that every DDP process receives the same number of batches, each process needs to know the number of samples in the dataset. However, since the user can filter out samples via command-line parameters (e.g. excluding examples above a certain length), this value cannot be known in advance. Computing it directly in my DataModule doesn’t work either: prepare_data() can’t be used to share state between processes, and setup() would repeat the costly filtering computation once for every GPU.
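To make this concrete, here is a minimal sketch of the situation (`shard_urls`, `max_length`, and the `"txt"`-length filter are placeholders for my actual command-line-driven filtering, not my real code):

```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader
import webdataset as wds


class MyDataModule(pl.LightningDataModule):
    def __init__(self, shard_urls, max_length):
        super().__init__()
        self.shard_urls = shard_urls
        self.max_length = max_length
        self.num_samples = None  # every DDP process needs this value

    def prepare_data(self):
        # Runs on a single process, but attributes set here are not
        # visible to the other DDP processes, so storing the count
        # here doesn't help.
        pass

    def setup(self, stage=None):
        # Runs on every GPU/process, so counting the filtered samples
        # here repeats the expensive pass over the data once per GPU.
        self.num_samples = sum(
            1
            for sample in wds.WebDataset(self.shard_urls)
            if len(sample["txt"]) <= self.max_length  # placeholder filter
        )

    def train_dataloader(self):
        dataset = wds.WebDataset(self.shard_urls).select(
            lambda s: len(s["txt"]) <= self.max_length
        )
        return DataLoader(dataset, batch_size=32)
```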
Is there a more elegant way to share this sort of data between the processes than resorting to environment variables or temporary files?