How to iterate over two train_dataloaders in each epoch?

I have the Lightning module below, where train_dataloader returns two dataloaders:

import lightning.pytorch as pl
from torch.utils.data import DataLoader


class CustomModel(pl.LightningModule):
    def __init__(self, **kwargs):
        super().__init__()
        ...

    def train_dataloader(self):
        # two separate training dataloaders
        return (
            DataLoader(
                self.train_dataset,
                collate_fn=collate_fn_0(),
            ),
            DataLoader(
                self.mbuffer_dataset,
                collate_fn=collate_fn_1(),
            ),
        )
    
runner = pl.Trainer(**kwargs)
runner.fit(CustomModel())

and I would like to use it in the following fashion:

for epoch in range(num_epochs):
  ... # code
  for batch in train_dataloaders[0]:
    outputs = model(batch)
  # do something
  for batch in train_dataloaders[1]:
    outputs = model(batch)

How can I do this?

Here is a sketch:

from lightning.pytorch.utilities import CombinedLoader

def train_dataloader(self):
    return CombinedLoader(
        [
            DataLoader(
                self.train_dataset,
                collate_fn=collate_fn_0(),
            ),
            DataLoader(
                self.mbuffer_dataset,
                collate_fn=collate_fn_1(),
            ),
        ],
        mode="sequential",  # <---- this is important
    )

This should work in Lightning 2.0.
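
In "sequential" mode the first loader is exhausted before the second one starts within each epoch, and (if I remember the 2.0 behavior correctly) training_step also receives a dataloader_idx argument, so you can branch on which loader the current batch came from. A minimal sketch of that inside CustomModel, where compute_loss is just a hypothetical helper standing in for your per-loader logic:

    def training_step(self, batch, batch_idx, dataloader_idx=0):
        if dataloader_idx == 0:
            # batch produced by the train_dataset loader
            loss = self.compute_loss(batch)  # hypothetical helper
        else:
            # batch produced by the mbuffer_dataset loader
            loss = self.compute_loss(batch)  # do something different here
        return loss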

Docs: Arbitrary iterable support — PyTorch Lightning 2.0.1 documentation
