Dealing with multiple datasets/dataloaders in Lightning

I’m setting up training on multiple datasets of different lengths, so the DataLoaders yield different numbers of batches. Since I want to use a loss function that combines losses across datasets (e.g. weighted differently for each dataset), I was trying to keep the dataloaders separate in a dictionary:

def train_dataloader(self):  # returns a dict of dataloaders, one per dataset
    train_loaders = {}
    for key, value in self.train_dict.items():
        train_loaders[key] = DataLoader(value,
                                        batch_size=self.batch_size,
                                        collate_fn=collate)
    return train_loaders

Then in training_step() I’m doing the following (I’m working on a contrastive learning project):

def training_step(self, batch, batch_idx):
    total_batch_loss = 0
    # batch is a dict: one (anchor, positive, negative) triplet per dataset
    for key, value in batch.items():
        anc, pos, neg = value
        # L2-normalize the embeddings along the feature dimension
        emb_anc = F.normalize(self.forward(anc.x), 2, dim=1)
        emb_pos = F.normalize(self.forward(pos.x), 2, dim=1)
        emb_neg = F.normalize(self.forward(neg.x), 2, dim=1)
        loss_dataset = LossFunc(emb_anc, emb_pos, emb_neg, anc.y, pos.y, neg.y)
        total_batch_loss += loss_dataset
    self.log("Loss", total_batch_loss, prog_bar=True, on_epoch=True)
    return total_batch_loss

The problem is that the datasets have different numbers of batches, so Lightning throws a StopIteration when the smallest dataset is exhausted. I considered simply concatenating everything into a single DataLoader, but then I don’t think I could compute a loss that is weighted per dataset the way I want.

When you create the dataloaders, you can wrap the dict in a CombinedLoader and specify the mode.
See "Arbitrary iterable support" in the PyTorch Lightning 2.0.1.post0 documentation.

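Would mode="max_size_cycle" make more sense for you? It ends the epoch only when the longest dataset is exhausted and cycles the shorter datasets in the meantime. (mode="max_size" also runs until the longest dataset is done, but returns None for the loaders that are already exhausted.)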
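
To make the behaviour of the modes concrete, here is a plain-Python illustration of how I understand the Lightning 2.0 docs: "max_size_cycle" restarts the shorter loaders, while "max_size" yields None for them once they run out. This is not Lightning's implementation; the function names and the toy "loaders" (lists of batch labels) are invented for the sketch.

```python
from itertools import cycle


def max_size_cycle(loaders):
    """Yield dict batches until the longest loader ends, restarting shorter ones."""
    longest = max(len(batches) for batches in loaders.values())
    iterators = {name: cycle(batches) for name, batches in loaders.items()}
    for _ in range(longest):
        yield {name: next(it) for name, it in iterators.items()}


def max_size(loaders):
    """Yield dict batches until the longest loader ends, padding with None."""
    longest = max(len(batches) for batches in loaders.values())
    for i in range(longest):
        yield {
            name: (batches[i] if i < len(batches) else None)
            for name, batches in loaders.items()
        }


# Two toy loaders with 2 and 5 batches respectively.
toy = {"small": ["s0", "s1"], "large": ["l0", "l1", "l2", "l3", "l4"]}
cycled = list(max_size_cycle(toy))
padded = list(max_size(toy))
```

With "max_size" the training_step would need to skip entries that are None; with "max_size_cycle" every key carries a real batch at every step, so the loop from the question can stay as it is.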

(This answer applies to Lightning 2.0.)
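
For the per-dataset weighting mentioned in the question, one possible sketch is a small helper that scales each dataset's loss before summing. The name `weighted_total`, the weight values, and the dataset keys are all hypothetical; the same arithmetic should work when the losses are tensors rather than floats.

```python
def weighted_total(losses, weights, default_weight=1.0):
    """Sum per-dataset losses, scaling each one by its dataset weight."""
    total = 0.0
    for name, loss in losses.items():
        # Datasets without an explicit weight fall back to default_weight.
        total = total + weights.get(name, default_weight) * loss
    return total


# Hypothetical weights and per-dataset loss values for a single step.
dataset_weights = {"small": 2.0, "large": 0.5}
per_dataset = {"small": 0.25, "large": 1.0}
total = weighted_total(per_dataset, dataset_weights)
```

In training_step() this would mean collecting loss_dataset into a dict keyed by dataset name and returning the weighted total instead of the plain sum.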