Arbitrary iterable support¶
Python iterables are objects that can be iterated or looped over. Examples of iterables in Python include lists and dictionaries.
In PyTorch, a torch.utils.data.DataLoader is also an iterable that typically retrieves data from a torch.utils.data.Dataset or torch.utils.data.IterableDataset.
The Trainer works with arbitrary iterables, but most people will use a torch.utils.data.DataLoader as the iterable to feed data to the model.
Multiple Iterables¶
In addition to supporting arbitrary iterables, the Trainer also supports arbitrary collections of iterables. Some examples of this are:
# a single iterable works as-is
return DataLoader(...)

# any other iterable also works, e.g. a plain list
return list(range(1000))

# pass loaders as a dict. This will create batches like this:
# {'a': batch_from_loader_a, 'b': batch_from_loader_b}
return {"a": DataLoader(...), "b": DataLoader(...)}

# pass loaders as a list. This will create batches like this:
# [batch_from_dl_1, batch_from_dl_2]
return [DataLoader(...), DataLoader(...)]

# arbitrary nesting also works. This will create batches like this:
# {'a': [batch_from_dl_1, batch_from_dl_2], 'b': [batch_from_dl_3, batch_from_dl_4]}
return {"a": [dl1, dl2], "b": [dl3, dl4]}
Lightning automatically collates the batches from multiple iterables based on a “mode”. This is done with our CombinedLoader class. The list of available modes can be found in the mode documentation.
By default, the "max_size_cycle" mode is used during training and the "sequential" mode is used during validation, testing, and prediction. To choose a different mode, you can use the CombinedLoader class directly with your mode of choice:
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.utilities import CombinedLoader

iterables = {"a": DataLoader(...), "b": DataLoader(...)}
combined_loader = CombinedLoader(iterables, mode="min_size")
model = ...
trainer = Trainer()
trainer.fit(model, combined_loader)
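If you want to inspect what a CombinedLoader produces, you can also iterate it directly outside the Trainer. The sketch below assumes the Lightning 2.x iteration protocol, where each step yields a (batch, batch_idx, dataloader_idx) triple; the range datasets, batch sizes, and keys are only illustrative:

from torch.utils.data import DataLoader
from lightning.pytorch.utilities import CombinedLoader

# two loaders of different lengths
iterables = {
    "a": DataLoader(range(6), batch_size=4),
    "b": DataLoader(range(15), batch_size=5),
}
combined_loader = CombinedLoader(iterables, mode="max_size_cycle")

# in "max_size_cycle" mode, the shorter iterable is cycled until the
# longest one is exhausted, so every batch contains both keys
for batch, batch_idx, dataloader_idx in combined_loader:
    print(batch, batch_idx, dataloader_idx)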
Currently, the trainer.validate, trainer.test, and trainer.predict methods only support the "sequential" mode, while the trainer.fit method does not support it. Support for this feature is tracked in this issue.
Note that when using the "sequential" mode, you need to add an additional dataloader_idx argument to some specific hooks. Lightning will raise an error informing you of this requirement.
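For example, when validating with two dataloaders, the step hook gains a dataloader_idx parameter. A minimal sketch, where the class name and per-dataloader branching are illustrative:

from lightning.pytorch import LightningModule

class LitModel(LightningModule):
    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        # dataloader_idx identifies which validation iterable produced this batch
        if dataloader_idx == 0:
            ...  # handle batches from the first dataloader
        else:
            ...  # handle batches from the second dataloader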
Using LightningDataModule¶
You can set more than one DataLoader in your LightningDataModule using its DataLoader hooks, and Lightning will use the correct one.
class DataModule(LightningDataModule):
    def train_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        # any iterable or collection of iterables
        return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)]

    def test_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.test_dataset)

    def predict_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.predict_dataset)
Using LightningModule Hooks¶
The exact same code as above works when overriding a LightningModule.
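For completeness, here is a minimal sketch of the same hooks defined directly on a LightningModule; the dataset attributes are assumed to be set elsewhere, e.g. in __init__:

from torch.utils.data import DataLoader
from lightning.pytorch import LightningModule

class LitModel(LightningModule):
    def train_dataloader(self):
        # any iterable or collection of iterables
        return DataLoader(self.train_dataset)

    def val_dataloader(self):
        # any iterable or collection of iterables
        return [DataLoader(self.val_dataset_1), DataLoader(self.val_dataset_2)]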
Passing the iterables to the Trainer¶
The same support for arbitrary iterables, or collections of iterables, applies to the dataloader arguments of fit(), validate(), test(), and predict().
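For example, a collection of iterables can be passed straight to these methods without a LightningDataModule. A minimal sketch, reusing the placeholder DataLoader(...) convention from above:

model = LitModel()
trainer = Trainer()

# a dict of iterables passed directly to fit()
trainer.fit(model, train_dataloaders={"a": DataLoader(...), "b": DataLoader(...)})

# a list of iterables passed directly to validate()
trainer.validate(model, dataloaders=[DataLoader(...), DataLoader(...)])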