Multi-GPU/Multi-Node training with WebDataset

This is my first time using WebDataset and I have multiple shards (about 60) with a large number of images. It was working as I would expect in the normal Dataset class when I was using a single GPU. However once I set the devices flag in Trainer to 2 I received the error ValueError: you need to add an explicit nodesplitter to your input pipeline for multi-node training Webdataset.

I saw two approaches for using multiple GPUs with WebDataset.

  1. Using .with_epoch
    According to the WebDataset GitHub page, I could simply use the with_epoch function on my dataset as follows:
dataset = wds.WebDataset(url, resampled=True).shuffle(1000).decode("rgb").to_tuple("png", "json").map(preprocess).with_epoch(10000)
dataloader = wds.WebLoader(dataset, batch_size=batch_size)
  2. Using ddp_equalize
    According to WebDataset MultiNode:
dataset_size, batch_size = 1282000, 64
dataset = wds.WebDataset(urls).decode("pil").shuffle(5000).batched(batch_size, partial=False)
loader = wds.WebLoader(dataset, num_workers=4)
loader = loader.ddp_equalize(dataset_size // batch_size)

Could someone please help me understand what is happening in these two pieces of code? In the second case, is dataset_size just a nominal size? Which, if either, is better? I would also appreciate an example of the best way to use WebDataset with PyTorch Lightning in a multi-GPU and multi-node scenario.

Hey @adhakal224 ,

According to WebDataset MultiNode, the second snippet you posted would be required but not sufficient on its own.

The way DDP training works is that your dataset is basically split into N random, non-overlapping subsets (with N being the number of processes/GPUs training simultaneously).

This is required because each training process should ideally retrieve a different part of the data, which increases the effective batch size during training (using the very same batch N times does not make sense).
This is ensured by the first part of the WebDataset MultiNode example:

dataset = wds.WebDataset(urls, splitter=my_split_by_worker, nodesplitter=my_split_by_node)

Most important here for DDP is the splitter. The nodesplitter would allow you to ensure that all workers on node0 only retrieve a specific part of the data, but to my understanding it isn't strictly required if you want to randomly split across processes (= workers).
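To make the two splitting levels more concrete, here is a minimal pure-Python sketch of the idea. It is not WebDataset's actual implementation: the helper names `shards_for_node` and `shards_for_worker` and the shard file names are made up for illustration. Each rank takes every `world_size`-th shard, and each DataLoader worker on that rank then takes every `num_workers`-th of those.

```python
def shards_for_node(urls, rank, world_size):
    """Keep every world_size-th shard, starting at this process's rank."""
    return urls[rank::world_size]


def shards_for_worker(urls, worker_id, num_workers):
    """Keep every num_workers-th shard, starting at this worker's id."""
    return urls[worker_id::num_workers]


# Hypothetical shard names; about 60 shards as in the question.
urls = [f"shard-{i:04d}.tar" for i in range(60)]
world_size, num_workers = 2, 4

# Map each (rank, worker) pair to the shards it would read.
assignment = {
    (rank, worker): shards_for_worker(
        shards_for_node(urls, rank, world_size), worker, num_workers
    )
    for rank in range(world_size)
    for worker in range(num_workers)
}

# Check that every shard is read by exactly one (rank, worker) pair.
all_assigned = sorted(s for shards in assignment.values() for s in shards)
print(assignment[(0, 0)][:3])
print(all_assigned == urls)
```

As far as I know, WebDataset ships helpers along these lines (wds.split_by_node and wds.split_by_worker), but please check the version you have installed before relying on them.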

The second part from that link

loader = wds.WebLoader(dataset, num_workers=4)
loader = loader.ddp_equalize(dataset_size // batch_size)

is important because DDP requires each process to handle exactly the same number of batches/samples, and if your dataset isn't evenly divisible by your batch size, this is not guaranteed by default. What ddp_equalize does is simply drop the last part of your dataset to satisfy this condition.

DistributedDataParallel training requires that each participating node receive exactly the same number of training batches as all others. […] You need to give the total number of batches in your dataset to ddp_equalize; it will compute the batches per node from this and equalize batches accordingly.

You need to apply ddp_equalize to the WebLoader rather than the Dataset.
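The arithmetic behind this can be sketched as follows. This is only an illustration, assuming the total batch count is divided evenly across ranks as the quoted documentation describes; `batches_per_rank` is a hypothetical helper, not part of WebDataset, and the numbers are the ones from the snippet in the question.

```python
def batches_per_rank(dataset_size, batch_size, world_size):
    """Number of full batches each rank runs so that no rank runs short."""
    total_batches = dataset_size // batch_size  # drop the partial last batch
    return total_batches // world_size          # drop batches that do not divide evenly


dataset_size, batch_size = 1282000, 64
for world_size in (1, 2, 8):
    print(world_size, batches_per_rank(dataset_size, batch_size, world_size))
```

So dataset_size does not have to be exact to the last sample, but if it is larger than what the shards really contain, some rank may run out of data before the others.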


Hi @justusschock! Thanks for the reply. I have a couple of questions. Firstly, in the line:
loader = loader.ddp_equalize(dataset_size // batch_size)
how would I know the dataset_size? I have about 100 shards with tens of millions of data points in total. Do I need to iterate through them all to find the true dataset size? If that is the case, is there a workaround for that?

Secondly, my current implementation follows the first approach and looks roughly like this:

self.dataset = wds.WebDataset(self.args.train_path, resampled=True)
self.dataset = self.dataset.shuffle(1000).decode('pil').to_tuple("a.jpg", "b.jpg", "metadata.json","__key__").map(self.do_transforms)

trainloader = wds.WebLoader(self.dataset, batch_size=None,
                    shuffle=False, pin_memory=True, num_workers=self.hparams.num_workers)
trainloader = trainloader.unbatched().shuffle(1000).batched(self.hparams.train_batch_size)

The train_epoch_length here is just some arbitrary value (I am using 10000). I got to this by following WebDataset. At least on a surface level the code seems to be working. Is there a reason this might fail or might be doing something unintended?

Hi @adhakal224 ,

I am not too familiar with the internals of WebDataset, but if it doesn't do some automatic DDP-aware sampling (which I assume it doesn't, since the docs explicitly say to use the splitter), you'll end up using the same data on every rank of the DDP training. In that case you wouldn't get a speedup over single-device training: visiting the same data on each device results in gradients, and thereby updates, equivalent to a single batch instead of different batches per device.
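To illustrate the point about effective gradients, here is a small toy sketch in plain Python (no PyTorch). It assumes DDP averages the gradients across ranks; the one-parameter model, the data, and the helper names `grad` and `ddp_average` are all made up for illustration.

```python
def grad(w, batch):
    """Gradient of the mean squared error of the model y = w * x with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in batch) / len(batch)


def ddp_average(w, batches_per_rank):
    """Average the per-rank gradients, mimicking DDP's gradient all-reduce."""
    return sum(grad(w, b) for b in batches_per_rank) / len(batches_per_rank)


w = 0.5
batch_a = [(1.0, 2.0), (2.0, 4.0)]
batch_b = [(3.0, 6.0), (4.0, 8.0)]

same = ddp_average(w, [batch_a, batch_a])       # both ranks read the same data
different = ddp_average(w, [batch_a, batch_b])  # each rank reads its own split

# Same data on both ranks: identical to the single-device gradient of one batch.
print(same == grad(w, batch_a))
# Different data per rank: matches the gradient of one batch twice as large.
print(abs(different - grad(w, batch_a + batch_b)) < 1e-12)
```

With identical data on every rank, the second GPU contributes nothing new; with properly split data, two ranks behave like one device with twice the batch size.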