Arbitrary iterable support

Python iterables are objects that can be looped over, such as lists and dictionaries. In PyTorch, a torch.utils.data.DataLoader is also an iterable; it typically retrieves data from a torch.utils.data.Dataset or torch.utils.data.IterableDataset.
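
As a minimal sketch of what "iterable" means here, the class below implements only `__iter__` and yields small lists as batches. `CountingBatches` and its parameters are hypothetical names invented for this illustration; they are not part of PyTorch or Lightning.

```python
from typing import Iterator, List


class CountingBatches:
    """Hypothetical iterable that yields fixed-size batches of integers."""

    def __init__(self, n_items: int, batch_size: int) -> None:
        self.n_items = n_items
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[List[int]]:
        # Each call returns a fresh generator, so the same object can be
        # looped over repeatedly (for example, once per epoch).
        for start in range(0, self.n_items, self.batch_size):
            stop = min(start + self.batch_size, self.n_items)
            yield list(range(start, stop))


batches = CountingBatches(n_items=5, batch_size=2)
first_pass = list(batches)   # [[0, 1], [2, 3], [4]]
second_pass = list(batches)  # same batches again on a second loop
```

Because the object restarts cleanly on every loop, it behaves like a list or a DataLoader from the point of view of code that only needs to iterate over it.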

The Trainer works with arbitrary iterables, but most people will use a torch.utils.data.DataLoader as the iterable to feed data to the model.

Multiple Iterables

In addition to supporting arbitrary iterables, the Trainer also supports arbitrary collections of iterables. Some examples of this are:

# pass a single loader
return DataLoader(...)

# pass any other iterable, for example a list
return list(range(1000))

# pass loaders as a dict. This will create batches like this:
# {'a': batch_from_loader_a, 'b': batch_from_loader_b}
return {"a": DataLoader(...), "b": DataLoader(...)}

# pass loaders as list. This will create batches like this:
# [batch_from_dl_1, batch_from_dl_2]
return [DataLoader(...), DataLoader(...)]

# pass loaders as a dict of lists. This will create batches like this:
# {'a': [batch_from_dl_1, batch_from_dl_2], 'b': [batch_from_dl_3, batch_from_dl_4]}
return {"a": [dl1, dl2], "b": [dl3, dl4]}
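
To make the batch shapes above concrete, here is a pure-Python sketch that walks a nested collection and yields batches with the same structure. It is an illustration only, not Lightning's implementation: `fake_loader` and `zip_nested` are hypothetical helpers, and the sketch uses plain `zip`, so it stops as soon as the shortest iterable is exhausted, which is only one of several possible combination strategies.

```python
from typing import Any, Iterator


def fake_loader(name: str, n_batches: int) -> Iterator[str]:
    # Hypothetical stand-in for a DataLoader: yields labelled "batches".
    return (f"{name}_batch{i}" for i in range(n_batches))


def zip_nested(collection: Any) -> Iterator[Any]:
    """Yield batches that mirror the dict/list structure of `collection`."""
    if isinstance(collection, dict):
        keys = list(collection)
        children = [zip_nested(collection[k]) for k in keys]
        for values in zip(*children):
            yield dict(zip(keys, values))
    elif isinstance(collection, (list, tuple)):
        children = [zip_nested(item) for item in collection]
        for values in zip(*children):
            yield list(values)
    else:
        # Anything else is treated as a leaf iterable of batches.
        yield from collection


nested = {
    "a": [fake_loader("dl1", 2), fake_loader("dl2", 2)],
    "b": [fake_loader("dl3", 2), fake_loader("dl4", 2)],
}
nested_batches = list(zip_nested(nested))
```

Each element of `nested_batches` is a dict with keys `"a"` and `"b"`, and each value is a two-element list holding one batch per loader.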

Lightning automatically collates the batches from multiple iterables according to a “mode”, using the CombinedLoader class. The available modes are listed in the CombinedLoader documentation.

By default, the "max_size_cycle" mode is used during training and the "sequential" mode is used during validation, testing, and prediction. To choose a different mode, you can use the CombinedLoader class directly with your mode of choice:

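from lightning.pytorch import Trainer
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import DataLoader

# each DataLoader needs a dataset; "..." stands for your own arguments
iterables = {"a": DataLoader(...), "b": DataLoader(...)}
combined_loader = CombinedLoader(iterables, mode="min_size")
model = ...
trainer = Trainer()
trainer.fit(model, combined_loader)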
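
The difference between the modes named above can be sketched in plain Python. This is a simplified illustration of the iteration order, not the CombinedLoader implementation: the three helper functions are hypothetical, they assume non-empty iterables that support `len()`, and the "sequential" sketch yields only the batch and the index of its iterable.

```python
from itertools import cycle
from typing import Any, Dict, Iterator, Sequence, Tuple


def iter_min_size(iterables: Dict[str, Sequence[Any]]) -> Iterator[Dict[str, Any]]:
    # Stop as soon as the shortest iterable is exhausted.
    keys = list(iterables)
    for values in zip(*iterables.values()):
        yield dict(zip(keys, values))


def iter_max_size_cycle(iterables: Dict[str, Sequence[Any]]) -> Iterator[Dict[str, Any]]:
    # Run for the length of the longest iterable and restart shorter ones.
    # Assumes every iterable is non-empty and has a length.
    keys = list(iterables)
    longest = max(len(it) for it in iterables.values())
    cyclers = [cycle(it) for it in iterables.values()]
    for _ in range(longest):
        yield {key: next(cycler) for key, cycler in zip(keys, cyclers)}


def iter_sequential(iterables: Dict[str, Sequence[Any]]) -> Iterator[Tuple[Any, int]]:
    # Fully consume each iterable before moving on to the next one.
    for iterable_idx, iterable in enumerate(iterables.values()):
        for batch in iterable:
            yield batch, iterable_idx


toy = {"a": [0, 1, 2], "b": ["x", "y"]}
min_size_batches = list(iter_min_size(toy))
max_size_cycle_batches = list(iter_max_size_cycle(toy))
sequential_batches = list(iter_sequential(toy))
```

With the toy input, the "min_size" sketch produces two combined batches, the "max_size_cycle" sketch produces three (reusing `"x"` for the third), and the "sequential" sketch produces five single batches, each tagged with the index of the iterable it came from.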

Currently, the trainer.validate, trainer.test, and trainer.predict methods support only the "sequential" mode, while the trainer.fit method does not support it. Support for this feature is tracked in this issue.

Note that when using the "sequential" mode, you need to add an extra dataloader_idx argument to some hooks (for example, validation_step). Lightning will raise an error informing you of this requirement.
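
A rough sketch of what the extra argument looks like in practice follows. `SketchModule` is a plain-Python stand-in (real code would subclass LightningModule), and `run_sequential_validation` is a hypothetical driver loop written for this example, not the Trainer's actual loop. Only the `validation_step(self, batch, batch_idx, dataloader_idx=0)` signature follows the Lightning convention.

```python
from typing import Any, Iterable, List, Sequence, Tuple


class SketchModule:
    """Plain-Python stand-in; real code would subclass LightningModule."""

    def __init__(self) -> None:
        self.seen: List[Tuple[int, int, Any]] = []

    def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        # dataloader_idx identifies which iterable produced this batch.
        self.seen.append((dataloader_idx, batch_idx, batch))


def run_sequential_validation(module: SketchModule, dataloaders: Sequence[Iterable[Any]]) -> None:
    # Hypothetical driver: consume each iterable fully, one after another.
    for dataloader_idx, dataloader in enumerate(dataloaders):
        for batch_idx, batch in enumerate(dataloader):
            module.validation_step(batch, batch_idx, dataloader_idx=dataloader_idx)


module = SketchModule()
run_sequential_validation(module, [["a0", "a1"], ["b0"]])
```

After the run, `module.seen` records each batch together with its iterable index, which is the information a hook needs in order to log or aggregate metrics separately per iterable.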

Using LightningDataModule

You can set more than one DataLoader in your LightningDataModule using its DataLoader hooks and Lightning will use the correct one.

from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader


class DataModule(LightningDataModule):
    def train_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        # any iterable or collection of iterables
        return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)]

    def test_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.test_dataset)

    def predict_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.predict_dataset)

Using LightningModule Hooks

The same hooks (train_dataloader, val_dataloader, test_dataloader, and predict_dataloader) work in exactly the same way when you override them in a LightningModule.

Passing the iterables to the Trainer

The same support for arbitrary iterables and collections of iterables applies to the dataloader arguments of fit(), validate(), test(), and predict().


© Copyright (c) 2018-2023, Lightning AI et al...