combined_loader
Classes
CombinedLoader: Combines different iterables under specific sampling modes.
- class lightning.pytorch.utilities.combined_loader.CombinedLoader(iterables, mode='min_size')
Bases: Iterable
Combines different iterables under specific sampling modes.
- Parameters
  - iterables (Any) – the iterable or collection of iterables to sample from.
  - mode (Literal['min_size', 'max_size_cycle', 'max_size', 'sequential']) – the mode to use. The following modes are supported:
    - min_size: stops after the shortest iterable (the one with the fewest items) is done.
    - max_size_cycle: stops after the longest iterable (the one with the most items) is done, while cycling through the rest of the iterables.
    - max_size: stops after the longest iterable (the one with the most items) is done, while returning None for the exhausted iterables.
    - sequential: completely consumes each iterable in turn and returns a triplet (data, idx, iterable_idx), where idx is the position of data within its own iterable and iterable_idx is the index of that iterable.
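To make the four modes concrete, here is a minimal pure-Python sketch that models them on plain lists instead of DataLoaders. The function name `combine` and the restriction to a dict of finite, non-empty sequences are assumptions made for illustration; this is not how CombinedLoader is implemented, only a model of the behavior described above.

```python
from itertools import cycle, zip_longest
from typing import Any, Dict, Iterator, Sequence


def combine(iterables: Dict[str, Sequence[Any]], mode: str = "min_size") -> Iterator[Any]:
    """Hypothetical pure-Python model of the four sampling modes.

    Assumes every child is a finite, non-empty sequence (so ``len`` works and
    ``cycle`` can restart it). This is an illustration, not Lightning's code.
    """
    keys = list(iterables)
    children = [iterables[k] for k in keys]
    if mode == "min_size":
        # zip stops as soon as the shortest child is exhausted
        for values in zip(*children):
            yield dict(zip(keys, values))
    elif mode == "max_size_cycle":
        # run for as many steps as the longest child, restarting shorter ones
        longest = max(len(c) for c in children)
        cycled = [cycle(c) for c in children]
        for _ in range(longest):
            yield {k: next(it) for k, it in zip(keys, cycled)}
    elif mode == "max_size":
        # run for as many steps as the longest child, padding exhausted ones with None
        for values in zip_longest(*children, fillvalue=None):
            yield dict(zip(keys, values))
    elif mode == "sequential":
        # drain each child in turn, tagging every item with (idx, iterable_idx)
        for iterable_idx, child in enumerate(children):
            for idx, data in enumerate(child):
                yield data, idx, iterable_idx
    else:
        raise ValueError(f"Unsupported mode: {mode!r}")


# Plain-list stand-ins for the batches of DataLoader(range(6), batch_size=4)
# and DataLoader(range(15), batch_size=5)
batches = {
    "a": [[0, 1, 2, 3], [4, 5]],
    "b": [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9], [10, 11, 12, 13, 14]],
}
for step in combine(batches, "max_size"):
    print(step)
```

Using `zip` and `itertools.zip_longest` keeps the min_size and max_size cases to one line each; only max_size_cycle needs explicit lengths, because it must know when the longest child is done while the others keep restarting.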
Examples
>>> from torch.utils.data import DataLoader
>>> from lightning.pytorch.utilities.combined_loader import CombinedLoader
>>> iterables = {'a': DataLoader(range(6), batch_size=4),
...              'b': DataLoader(range(15), batch_size=5)}
>>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
>>> len(combined_loader)
3
>>> for batch in combined_loader:
...     print(batch)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}

>>> combined_loader = CombinedLoader(iterables, 'max_size')
>>> len(combined_loader)
3
>>> for batch in combined_loader:
...     print(batch)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}
{'a': None, 'b': tensor([10, 11, 12, 13, 14])}

>>> combined_loader = CombinedLoader(iterables, 'min_size')
>>> len(combined_loader)
2
>>> for batch in combined_loader:
...     print(batch)
{'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}
{'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}

>>> combined_loader = CombinedLoader(iterables, 'sequential')
>>> len(combined_loader)
5
>>> for batch, batch_idx, dataloader_idx in combined_loader:
...     print(f"{batch} {batch_idx=} {dataloader_idx=}")
tensor([0, 1, 2, 3]) batch_idx=0 dataloader_idx=0
tensor([4, 5]) batch_idx=1 dataloader_idx=0
tensor([0, 1, 2, 3, 4]) batch_idx=0 dataloader_idx=1
tensor([5, 6, 7, 8, 9]) batch_idx=1 dataloader_idx=1
tensor([10, 11, 12, 13, 14]) batch_idx=2 dataloader_idx=1
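The lengths printed in these examples appear to follow a simple rule per mode: the minimum of the child lengths for min_size, the maximum for max_size_cycle and max_size, and the sum for sequential. The sketch below reproduces the 3 / 3 / 2 / 5 values from the two loaders' batch counts. `combined_length` is a hypothetical helper, not part of Lightning's API, and it only models the examples above; it does not account for anything else that might affect `len` in practice.

```python
import math
from typing import Sequence


def combined_length(lengths: Sequence[int], mode: str = "min_size") -> int:
    """Hypothetical helper: how many items each mode yields, given the
    number of items in every child iterable."""
    if mode == "min_size":
        return min(lengths)  # limited by the shortest child
    if mode in ("max_size_cycle", "max_size"):
        return max(lengths)  # limited by the longest child
    if mode == "sequential":
        return sum(lengths)  # every item of every child, one after another
    raise ValueError(f"Unsupported mode: {mode!r}")


# With the default drop_last=False, a DataLoader yields
# ceil(len(dataset) / batch_size) batches
lengths = [math.ceil(6 / 4), math.ceil(15 / 5)]  # [2, 3]
for mode in ("max_size_cycle", "max_size", "min_size", "sequential"):
    print(mode, combined_length(lengths, mode))
```

Because 'a' has only 2 batches against 3 for 'b', min_size drops the third batch of 'b', while both max_size modes keep it and differ only in what they return for 'a' at that step.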