combined_loader

Classes

CombinedLoader

Combines different iterables under specific sampling modes.

class lightning.pytorch.utilities.combined_loader.CombinedLoader(iterables, mode='min_size')

Bases: Iterable

Parameters:
  • iterables (Any) – the iterable or collection of iterables to sample from.

  • mode (Literal['min_size', 'max_size_cycle', 'max_size', 'sequential']) –

    the mode to use. The following modes are supported:

    • min_size: stops after the shortest iterable (the one with the lowest number of items) is done.

    • max_size_cycle: stops after the longest iterable (the one with most items) is done, while cycling through the rest of the iterables.

    • max_size: stops after the longest iterable (the one with most items) is done, while returning None for the exhausted iterables.

    • sequential: completely consumes each iterable one after another. Each step returns a triplet (data, idx, iterable_idx), where data is a single item from the current iterable rather than a collection.

Examples

>>> from torch.utils.data import DataLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
...              'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> _ = iter(combined_loader)  # len() requires the iterator to have been created first
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'max_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': None, 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
2
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
5
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
tensor([0, 1, 2, 3]), batch_idx=0, dataloader_idx=0
tensor([4, 5]), batch_idx=1, dataloader_idx=0
tensor([0, 1, 2, 3, 4]), batch_idx=0, dataloader_idx=1
tensor([5, 6, 7, 8, 9]), batch_idx=1, dataloader_idx=1
tensor([10, 11, 12, 13, 14]), batch_idx=2, dataloader_idx=1
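
To make the four modes concrete without Lightning or PyTorch, the following is a hypothetical pure-Python model of the iteration order shown above. The function name `combine` is an illustrative assumption, not part of the Lightning API. It takes a dict of sized, indexable sequences (lists of pre-made batches) in place of DataLoaders, and wrapping the index modulo the length stands in for restarting an exhausted iterator in max_size_cycle mode; Lightning's actual implementation works on iterators and may differ in detail.

```python
from typing import Any, Dict, Iterator, Sequence, Tuple


def combine(
    iterables: Dict[str, Sequence[Any]], mode: str = "min_size"
) -> Iterator[Tuple[Any, int, int]]:
    """Hypothetical model of the CombinedLoader modes over plain sequences.

    Yields (batch, batch_idx, dataloader_idx) triplets. Assumes every sequence
    is non-empty when mode == "max_size_cycle".
    """
    keys = list(iterables)
    if mode == "sequential":
        # Exhaust each iterable in turn; batch_idx restarts at 0 for each one.
        for dataloader_idx, key in enumerate(keys):
            for batch_idx, item in enumerate(iterables[key]):
                yield item, batch_idx, dataloader_idx
        return

    lengths = [len(iterables[key]) for key in keys]
    if mode == "min_size":
        steps = min(lengths)  # stop when the shortest iterable is done
    elif mode in ("max_size_cycle", "max_size"):
        steps = max(lengths)  # stop when the longest iterable is done
    else:
        raise ValueError(f"unsupported mode: {mode!r}")

    for batch_idx in range(steps):
        batch = {}
        for key, length in zip(keys, lengths):
            if batch_idx < length:
                batch[key] = iterables[key][batch_idx]
            elif mode == "max_size_cycle":
                # Wrap around, standing in for restarting the exhausted iterator.
                batch[key] = iterables[key][batch_idx % length]
            else:
                # max_size: an exhausted iterable contributes None.
                batch[key] = None
        # In the non-sequential modes every step reports dataloader_idx=0.
        yield batch, batch_idx, 0
```

With `a = [[0, 1, 2, 3], [4, 5]]` and `b = [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]`, `list(combine({'a': a, 'b': b}, 'max_size_cycle'))` reproduces the three max_size_cycle steps above, with lists in place of tensors.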

reset()

Reset the state and shut down any workers.

Return type:

None

property batch_sampler: Any

Return a collection of batch samplers extracted from the iterables.

property flattened: List[Any]

Return the flat list of iterables.
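
Because `iterables` may be a nested collection (the examples use a dict), it helps to picture how such a collection maps to a flat list and how per-step items can be put back into the original shape, as in the examples, where each batch keeps its 'a' and 'b' keys. The sketch below is a hypothetical, stdlib-only illustration: `flatten` and `unflatten` are illustrative names, and treating only dicts, lists, and tuples as containers is an assumption that may not match Lightning's own rules.

```python
from typing import Any, Iterator, List


def flatten(collection: Any) -> List[Any]:
    """Collect the leaves of a nested dict/list/tuple in depth-first order."""
    if isinstance(collection, dict):
        return [leaf for value in collection.values() for leaf in flatten(value)]
    if isinstance(collection, (list, tuple)):
        return [leaf for value in collection for leaf in flatten(value)]
    return [collection]  # anything else (e.g. a DataLoader) is treated as a leaf


def unflatten(leaves: List[Any], template: Any) -> Any:
    """Rebuild the structure of `template`, filling leaf slots from `leaves` in order."""
    remaining: Iterator[Any] = iter(leaves)

    def build(node: Any) -> Any:
        if isinstance(node, dict):
            return {key: build(value) for key, value in node.items()}
        if isinstance(node, (list, tuple)):
            # type(node) keeps lists as lists and tuples as tuples
            # (namedtuples are not supported by this sketch).
            return type(node)(build(value) for value in node)
        return next(remaining)

    return build(template)
```

For example, `flatten({"a": "A", "b": ["B1", ("B2", "B3")]})` gives `['A', 'B1', 'B2', 'B3']`, and `unflatten([1, 2, 3, 4], {"a": "A", "b": ["B1", ("B2", "B3")]})` gives `{'a': 1, 'b': [2, (3, 4)]}`.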

property iterables: Any

Return the original collection of iterables.

property limits: Optional[List[Union[int, float]]]

Optional limits on how many items are drawn from each iterable (None means no limits are applied).
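
As a rough guide to how `limits` relates to `len()`, the hypothetical helper below assumes each limit caps the number of items taken from the matching iterable, after which the mode determines the total: the minimum for min_size, the maximum for max_size_cycle and max_size, and the sum for sequential. Without limits, this reproduces the lengths in the examples (2, 3, 3, and 5 for iterables of 2 and 3 batches). `combined_length` is an illustrative name, not a Lightning function.

```python
from typing import Optional, Sequence, Union


def combined_length(
    lengths: Sequence[int],
    mode: str,
    limits: Optional[Sequence[Union[int, float]]] = None,
) -> int:
    """Hypothetical: number of steps for the given per-iterable lengths and mode."""
    if limits is not None:
        # Assumption: each limit caps how many items are drawn from its iterable.
        lengths = [min(length, limit) for length, limit in zip(lengths, limits)]
    if mode == "min_size":
        total = min(lengths)
    elif mode in ("max_size_cycle", "max_size"):
        total = max(lengths)
    elif mode == "sequential":
        total = sum(lengths)
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    # float("inf") never wins min() against a finite length, but a finite float
    # limit such as 2.0 can, so cast the result back to int.
    return int(total)
```

For instance, `combined_length([2, 3], "sequential", limits=[1, float("inf")])` gives 4, since only one item is taken from the first iterable.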

property sampler: Any

Return a collection of samplers extracted from the iterables.