The LightningDataModule docs mention:
```python
def prepare_data(self):
    # download, split, etc...
    # only called on 1 GPU/TPU in distributed

def setup(self, stage):
    # make assignments here (val/train/test split)
    # called on every process in DDP
```
I’m a bit confused here. In my case there is nothing to download: I just load several torch.Dataset objects for train/val/test, which are later returned by the DataModule’s *_dataloader functions. So I’m assuming I should load all of the datasets in setup() and not in prepare_data(). Does that mean the whole dataset will be loaded on every GPU/process I plan to use for training?
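For context, my setup looks roughly like this (simplified; MyDataset, the splits, and the batch size are just stand-ins for my real code):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class MyDataset(Dataset):
    # stand-in for my real Dataset: random features and binary labels
    def __init__(self, split):
        self.x = torch.randn(100, 3)
        self.y = torch.randint(0, 2, (100,))

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        return self.x[idx], self.y[idx]


class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size=32):
        super().__init__()
        self.batch_size = batch_size

    def prepare_data(self):
        # nothing to download in my case
        pass

    def setup(self, stage=None):
        # load the already-prepared Datasets; called on every process in DDP
        self.train_set = MyDataset("train")
        self.val_set = MyDataset("val")
        self.test_set = MyDataset("test")

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size)
```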
Also, how does Lightning pass batches around for multi-GPU training? My understanding was that the dataset stays on the CPU, Lightning pulls a batch from the *_dataloader and moves it to a GPU, and within an epoch the same batch is never sent to more than one GPU. Is that right? But if I load my dataset in prepare_data(), will multiple copies be made? And if there are no transformations/calculations to perform, does everything happen on the CPU just once?
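To make that last question more concrete, this is roughly my mental model of what each process ends up seeing under DDP (just an illustration with plain PyTorch’s DistributedSampler, not something taken from Lightning’s internals):

```python
import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

# toy dataset that lives in CPU memory
dataset = TensorDataset(torch.arange(16))

# what I imagine each DDP process effectively gets: a sampler that hands
# it a disjoint slice of the indices, so no batch is seen by two GPUs
for rank in range(2):  # pretend there are 2 GPUs / processes
    sampler = DistributedSampler(dataset, num_replicas=2, rank=rank, shuffle=False)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)
    for (batch,) in loader:
        # in real training this batch would then be moved to cuda:rank
        print(f"rank {rank} gets samples {batch.tolist()}")
```

Is that roughly what Lightning does for me, or does each process iterate over the full dataset?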