combined_loader
Classes
CombinedLoader | Combines different iterables under specific sampling modes.
- class lightning.pytorch.utilities.combined_loader.CombinedLoader(iterables, mode='min_size')
Bases: Iterable
Combines different iterables under specific sampling modes.
- Parameters:
  - iterables (Any) – the iterable or collection of iterables to sample from.
  - mode (Literal['min_size', 'max_size_cycle', 'max_size', 'sequential']) – the mode to use. Defaults to 'min_size'. The following modes are supported:
    - min_size: stops after the shortest iterable (the one with the fewest items) is done.
    - max_size_cycle: stops after the longest iterable (the one with the most items) is done, while cycling through the rest of the iterables.
    - max_size: stops after the longest iterable (the one with the most items) is done, while returning None for the exhausted iterables.
    - sequential: completely consumes each iterable in turn and returns a triplet (data, idx, iterable_idx).
    As the examples below show, iteration yields (batch, batch_idx, dataloader_idx) triplets in every mode; outside sequential mode, dataloader_idx is 0 and batch has the same structure as iterables.
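To make the four stopping rules concrete, here is a minimal pure-Python model of them. It is a hypothetical sketch, not Lightning's implementation: the function `combine` is made up for illustration, and it assumes a dict of finite, non-empty sequences of already-built batches (plain lists standing in for `DataLoader` batches). It yields `(batch, batch_idx, dataloader_idx)` triplets shaped like the doctest output below.

```python
from itertools import cycle, zip_longest
from typing import Any, Dict, Iterator, Sequence, Tuple

# Hypothetical model of CombinedLoader's modes; NOT Lightning's implementation.
# Assumes a dict of finite, non-empty sequences of already-built batches.
Triplet = Tuple[Any, int, int]


def combine(iterables: Dict[str, Sequence[Any]], mode: str) -> Iterator[Triplet]:
    keys = list(iterables)
    seqs = [iterables[k] for k in keys]
    if mode == "min_size":
        # zip() stops as soon as the shortest sequence is exhausted
        for i, values in enumerate(zip(*seqs)):
            yield dict(zip(keys, values)), i, 0
    elif mode == "max_size_cycle":
        # cycle() restarts the shorter sequences until the longest one is done
        longest = max(len(s) for s in seqs)
        cycled = [cycle(s) for s in seqs]
        for i in range(longest):
            yield {k: next(c) for k, c in zip(keys, cycled)}, i, 0
    elif mode == "max_size":
        # zip_longest() pads exhausted sequences with None
        for i, values in enumerate(zip_longest(*seqs)):
            yield dict(zip(keys, values)), i, 0
    elif mode == "sequential":
        # exhaust each sequence in turn; the third element says which one
        for iterable_idx, seq in enumerate(seqs):
            for idx, data in enumerate(seq):
                yield data, idx, iterable_idx
    else:
        raise ValueError(f"unsupported mode: {mode!r}")


# Lists stand in for the DataLoader batches used in the doctests.
batches = {
    "a": [[0, 1, 2, 3], [4, 5]],
    "b": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
}
for mode in ("max_size_cycle", "max_size", "min_size", "sequential"):
    print(mode, len(list(combine(batches, mode))))  # 3, 3, 2 and 5 triplets respectively
```

In this sketch the standard-library helpers map one-to-one onto the modes: `zip` for `min_size`, `itertools.cycle` for `max_size_cycle`, and `itertools.zip_longest`, which pads with `None`, for `max_size`.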
Examples
>>> from torch.utils.data import DataLoader
>>> from lightning.pytorch.utilities.combined_loader import CombinedLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
...              'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> _ = iter(combined_loader)  # len() is only available after iter() has been called
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'max_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
3
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
{'a': None, 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
2
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> _ = iter(combined_loader)
>>> len(combined_loader)
5
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
tensor([0, 1, 2, 3]), batch_idx=0, dataloader_idx=0
tensor([4, 5]), batch_idx=1, dataloader_idx=0
tensor([0, 1, 2, 3, 4]), batch_idx=0, dataloader_idx=1
tensor([5, 6, 7, 8, 9]), batch_idx=1, dataloader_idx=1
tensor([10, 11, 12, 13, 14]), batch_idx=2, dataloader_idx=1
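The `len(combined_loader)` values in these doctests follow from simple arithmetic on the per-iterable batch counts. The sketch below reproduces them. `num_batches` and `combined_len` are hypothetical helpers, not part of Lightning, and the sketch assumes sized iterables and the default `DataLoader(drop_last=False)`, which keeps a final partial batch.

```python
from typing import Mapping


def num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    # Hypothetical helper: batches produced by a map-style DataLoader.
    # Without drop_last, the remainder forms one extra, smaller batch (ceiling division).
    if drop_last:
        return num_samples // batch_size
    return (num_samples + batch_size - 1) // batch_size


def combined_len(lengths: Mapping[str, int], mode: str) -> int:
    # Hypothetical helper: expected number of combined batches for each mode.
    if mode == "min_size":
        return min(lengths.values())  # stop with the shortest iterable
    if mode in ("max_size_cycle", "max_size"):
        return max(lengths.values())  # stop with the longest iterable
    if mode == "sequential":
        return sum(lengths.values())  # every batch of every iterable, one after another
    raise ValueError(f"unsupported mode: {mode!r}")


lengths = {"a": num_batches(6, 4), "b": num_batches(15, 5)}  # {'a': 2, 'b': 3}
for mode in ("max_size_cycle", "max_size", "min_size", "sequential"):
    print(mode, combined_len(lengths, mode))  # 3, 3, 2 and 5 respectively
```

With 6 samples in batches of 4 and 15 samples in batches of 5, the per-iterable counts are 2 and 3, which yields the lengths 3, 3, 2 and 5 printed by the doctests.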