I’m working in an environment with regular HDDs that are shared among many users. I/O performance is too poor to read and parse the data on the fly, so I have to load my data into memory.
I have a single node with 4 GPUs (node resources are not shared; the underlying storage is). When training in DDP mode, each process loads the entire dataset into memory. This works for my current dataset but won’t for larger ones, and I’d like to avoid it anyway, since each process only uses a subset of the data and the rest is redundant. The dataset preparation is done directly in the LightningModule (without explicitly using a LightningDataModule).
From my understanding, the following should solve the problem:
- Disable the automatic addition of the DistributedSampler in the Trainer by passing replace_sampler_ddp=False.
- Pass the process's rank to the Dataset and have it load only the corresponding shard (see the sketch after this list).
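To make the second point concrete, this is a rough sketch of what I have in mind. ShardedInMemoryDataset and its _load helper are placeholder names of mine, not Lightning APIs:

```python
import torch
from torch.utils.data import Dataset


class ShardedInMemoryDataset(Dataset):
    """Loads only the shard of samples assigned to this process into memory."""

    def __init__(self, sample_paths, rank, world_size):
        # each process keeps a disjoint, interleaved subset of the files:
        # rank 0 gets indices 0, world_size, 2*world_size, ...,
        # rank 1 gets 1, world_size + 1, ..., and so on
        shard_paths = sample_paths[rank::world_size]
        self.samples = [self._load(path) for path in shard_paths]

    def _load(self, path):
        # placeholder for the real parsing logic; in my case this is the
        # expensive read that I want to do exactly once per sample
        return torch.load(path)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]
```

With the data already split per rank like this, the DistributedSampler would shard it a second time, which is why I want replace_sampler_ddp=False.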
So my questions are:
- How do I achieve the above, i.e. get the process's rank inside the LightningModule and pass it on to my dataset object?
- Is there a better way to do this using existing pytorch-lightning components?
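For context, this is roughly how I imagine wiring it up inside the LightningModule, using the ShardedInMemoryDataset sketched above. I fall back to torch.distributed for the rank/world-size lookup because I'm not sure which Lightning attributes (self.global_rank? self.trainer.world_size?) are safe to use at this point; MyModel, data_paths, the batch size, and the Trainer flags are just placeholders for my setup:

```python
import pytorch_lightning as pl
import torch.distributed as dist
from torch.utils.data import DataLoader


class MyModel(pl.LightningModule):
    def __init__(self, data_paths):
        super().__init__()
        self.data_paths = data_paths
        # ... model definition, training_step, configure_optimizers, etc. elided

    def train_dataloader(self):
        # is there a Lightning-native way to get these here, or is raw
        # torch.distributed the intended approach?
        if dist.is_available() and dist.is_initialized():
            rank, world_size = dist.get_rank(), dist.get_world_size()
        else:
            rank, world_size = 0, 1  # single-process fallback

        dataset = ShardedInMemoryDataset(self.data_paths, rank, world_size)
        # no DistributedSampler here, since the data is already sharded per rank
        return DataLoader(dataset, batch_size=32, shuffle=True)


# assuming a Lightning version where these Trainer flags exist
# (they have been renamed in newer releases)
trainer = pl.Trainer(gpus=4, accelerator="ddp", replace_sampler_ddp=False)
```

In particular, I'm not sure whether the process group / rank information is already set up by the time train_dataloader is called.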