Source code for lightning.pytorch.utilities.combined_loader

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from collections.abc import Iterable
from typing import Any, Callable, Dict, Iterator, List, Literal, Optional, Tuple, Type, Union

from torch.utils.data.dataloader import _BaseDataLoaderIter, _MultiProcessingDataLoaderIter
from typing_extensions import Self, TypedDict, override

from lightning.fabric.utilities.data import sized_len
from lightning.fabric.utilities.types import _Stateful
from lightning.pytorch.utilities._pytree import _map_and_unflatten, _tree_flatten, tree_unflatten

_ITERATOR_RETURN = Tuple[Any, int, int]  # batch, batch_idx, dataloader_idx


class _ModeIterator(Iterator[_ITERATOR_RETURN]):
    def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None:
        if limits is not None and len(limits) != len(iterables):
            raise ValueError(f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(iterables)})")
        self.iterables = iterables
        self.iterators: List[Iterator] = []
        self._idx = 0  # what would be batch_idx
        self.limits = limits

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        raise NotImplementedError

    @override
    def __iter__(self) -> Self:
        self.iterators = [iter(iterable) for iterable in self.iterables]
        self._idx = 0
        return self

    def __len__(self) -> int:
        raise NotImplementedError

    def reset(self) -> None:
        self.iterators = []
        self._idx = 0

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()

        # workaround an inconvenient `NotImplementedError`:
        # https://github.com/pytorch/pytorch/blob/v2.0.0/torch/utils/data/dataloader.py#L652-L658
        state["iterators"] = [
            None if isinstance(iterator, _BaseDataLoaderIter) else iterator_state
            for iterator, iterator_state in zip(self.iterators, state["iterators"])
        ]

        return state


class _MaxSizeCycle(_ModeIterator):
    def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None:
        super().__init__(iterables, limits)
        self._consumed: List[bool] = []

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        n = len(self.iterators)
        out = [None] * n  # values per iterator
        for i in range(n):
            try:
                out[i] = next(self.iterators[i])
            except StopIteration:
                self._consumed[i] = True
                if all(self._consumed):
                    raise
                # reset the consumed dataloader
                self.iterators[i] = iter(self.iterables[i])
                out[i] = next(self.iterators[i])
        index = self._idx
        self._idx += 1
        return out, index, 0

    @override
    def __iter__(self) -> Self:
        super().__iter__()
        self._consumed = [False] * len(self.iterables)
        return self

    @override
    def __len__(self) -> int:
        lengths = _get_iterables_lengths(self.iterables)
        if self.limits is not None:
            return max(min(length, limit) for length, limit in zip(lengths, self.limits))  # type: ignore[return-value]
        return max(lengths)  # type: ignore[return-value]

    @override
    def reset(self) -> None:
        super().reset()
        self._consumed = []
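
# A minimal sketch of how `_MaxSizeCycle` behaves (illustrative only; plain ranges stand in for
# the dataloaders that CombinedLoader normally passes in). The shorter iterable is restarted
# until the longest one is exhausted, and the last element of each tuple is always 0 because the
# per-iterable values are combined into a single "dataloader":
#
#     >>> it = _MaxSizeCycle([range(3), range(5)])
#     >>> len(it)
#     5
#     >>> list(it)
#     [([0, 0], 0, 0), ([1, 1], 1, 0), ([2, 2], 2, 0), ([0, 3], 3, 0), ([1, 4], 4, 0)]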


class _MinSize(_ModeIterator):
    @override
    def __next__(self) -> _ITERATOR_RETURN:
        out = [next(it) for it in self.iterators]
        index = self._idx
        self._idx += 1
        return out, index, 0

    @override
    def __len__(self) -> int:
        lengths = _get_iterables_lengths(self.iterables)
        return min(lengths + self.limits) if self.limits is not None else min(lengths)  # type: ignore[return-value]
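
# A minimal sketch of `_MinSize` with plain ranges standing in for dataloaders (illustrative
# only): iteration stops as soon as the shortest iterable raises StopIteration.
#
#     >>> it = _MinSize([range(3), range(5)])
#     >>> len(it)
#     3
#     >>> list(it)
#     [([0, 0], 0, 0), ([1, 1], 1, 0), ([2, 2], 2, 0)]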


class _Sequential(_ModeIterator):
    def __init__(self, iterables: List[Iterable], limits: Optional[List[Union[int, float]]] = None) -> None:
        super().__init__(iterables, limits)
        self._iterator_idx = 0  # what would be dataloader_idx

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        n = len(self.iterables)
        if n == 0 or self._iterator_idx >= n:
            raise StopIteration

        # if limits are set, go to the correct iterator
        if self.limits is not None:
            while self.limits[self._iterator_idx] <= self._idx:
                self._use_next_iterator()
                if self._iterator_idx >= n:
                    raise StopIteration

        try:
            out = next(self.iterators[0])
        except StopIteration:
            # try the next iterator
            self._use_next_iterator()
            return self.__next__()
        index = self._idx
        self._idx += 1
        return out, index, self._iterator_idx

    @override
    def __iter__(self) -> Self:
        self._iterator_idx = 0
        self._idx = 0
        self._load_current_iterator()
        return self

    @override
    def __len__(self) -> int:
        lengths = _get_iterables_lengths(self.iterables)
        if self.limits is not None:
            return sum(min(length, limit) for length, limit in zip(lengths, self.limits))  # type: ignore[misc]
        return sum(lengths)  # type: ignore[arg-type]

    @override
    def reset(self) -> None:
        super().reset()
        self._iterator_idx = 0

    def _load_current_iterator(self) -> None:
        # Load a single DataLoader, prevents multiple sets of workers from starting unnecessarily
        if self._iterator_idx < len(self.iterables):
            self.iterators = [iter(self.iterables[self._iterator_idx])]
        else:
            # No more iterables to step through, return an empty list
            self.iterators = []

    def _use_next_iterator(self) -> None:
        self._iterator_idx += 1
        self._idx = 0
        self._load_current_iterator()
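
# A minimal sketch of `_Sequential` with plain ranges (illustrative only): each iterable is
# consumed completely before the next one starts, `batch_idx` restarts at 0 per iterable, and the
# third element reports which iterable the value came from.
#
#     >>> it = _Sequential([range(2), range(3)])
#     >>> len(it)
#     5
#     >>> list(it)
#     [(0, 0, 0), (1, 1, 0), (0, 0, 1), (1, 1, 1), (2, 2, 1)]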


class _MaxSize(_ModeIterator):
    @override
    def __next__(self) -> _ITERATOR_RETURN:
        n = len(self.iterators)
        out = [None] * n
        all_exhausted = True
        for i in range(n):
            with contextlib.suppress(StopIteration):
                out[i] = next(self.iterators[i])
                all_exhausted = False
        if all_exhausted:
            raise StopIteration
        index = self._idx
        self._idx += 1
        return out, index, 0

    @override
    def __len__(self) -> int:
        lengths = _get_iterables_lengths(self.iterables)
        if self.limits is not None:
            return max(min(length, limit) for length, limit in zip(lengths, self.limits))  # type: ignore[return-value]
        return max(lengths)  # type: ignore[return-value]
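
# A minimal sketch of `_MaxSize` with plain ranges (illustrative only): exhausted iterables
# contribute None while the remaining ones keep producing values.
#
#     >>> it = _MaxSize([range(2), range(3)])
#     >>> len(it)
#     3
#     >>> list(it)
#     [([0, 0], 0, 0), ([1, 1], 1, 0), ([None, 2], 2, 0)]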


class _CombinationMode(TypedDict):
    fn: Callable[[List[int]], int]
    iterator: Type[_ModeIterator]


_SUPPORTED_MODES = {
    "min_size": _CombinationMode(fn=min, iterator=_MinSize),
    "max_size_cycle": _CombinationMode(fn=max, iterator=_MaxSizeCycle),
    "max_size": _CombinationMode(fn=max, iterator=_MaxSize),
    "sequential": _CombinationMode(fn=sum, iterator=_Sequential),
}
_LITERAL_SUPPORTED_MODES = Literal["min_size", "max_size_cycle", "max_size", "sequential"]
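
# Each mode entry pairs the iterator class used by `CombinedLoader.__iter__` with the reduction
# applied by `CombinedLoader._dataset_length`, e.g. (illustrative only):
#
#     >>> _SUPPORTED_MODES["max_size_cycle"]["iterator"] is _MaxSizeCycle
#     True
#     >>> _SUPPORTED_MODES["min_size"]["fn"]([3, 5])
#     3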


class CombinedLoader(Iterable):
    """Combines different iterables under specific sampling modes.

    Args:
        iterables: the iterable or collection of iterables to sample from.
        mode: the mode to use. The following modes are supported:

            * ``min_size``: stops after the shortest iterable (the one with the lowest number of items) is done.
            * ``max_size_cycle``: stops after the longest iterable (the one with most items) is done, while cycling
              through the rest of the iterables.
            * ``max_size``: stops after the longest iterable (the one with most items) is done, while returning None
              for the exhausted iterables.
            * ``sequential``: completely consumes each iterable sequentially, and returns a triplet
              ``(data, idx, iterable_idx)``

    Examples:
        >>> from torch.utils.data import DataLoader
        >>> iterables = {'a': DataLoader(range(6), batch_size=4),
        ...              'b': DataLoader(range(15), batch_size=5)}
        >>> combined_loader = CombinedLoader(iterables, 'max_size_cycle')
        >>> _ = iter(combined_loader)
        >>> len(combined_loader)
        3
        >>> for batch, batch_idx, dataloader_idx in combined_loader:
        ...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0

        >>> combined_loader = CombinedLoader(iterables, 'max_size')
        >>> _ = iter(combined_loader)
        >>> len(combined_loader)
        3
        >>> for batch, batch_idx, dataloader_idx in combined_loader:
        ...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0
        {'a': None, 'b': tensor([10, 11, 12, 13, 14])}, batch_idx=2, dataloader_idx=0

        >>> combined_loader = CombinedLoader(iterables, 'min_size')
        >>> _ = iter(combined_loader)
        >>> len(combined_loader)
        2
        >>> for batch, batch_idx, dataloader_idx in combined_loader:
        ...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
        {'a': tensor([0, 1, 2, 3]), 'b': tensor([0, 1, 2, 3, 4])}, batch_idx=0, dataloader_idx=0
        {'a': tensor([4, 5]), 'b': tensor([5, 6, 7, 8, 9])}, batch_idx=1, dataloader_idx=0

        >>> combined_loader = CombinedLoader(iterables, 'sequential')
        >>> _ = iter(combined_loader)
        >>> len(combined_loader)
        5
        >>> for batch, batch_idx, dataloader_idx in combined_loader:
        ...     print(f"{batch}, {batch_idx=}, {dataloader_idx=}")
        tensor([0, 1, 2, 3]), batch_idx=0, dataloader_idx=0
        tensor([4, 5]), batch_idx=1, dataloader_idx=0
        tensor([0, 1, 2, 3, 4]), batch_idx=0, dataloader_idx=1
        tensor([5, 6, 7, 8, 9]), batch_idx=1, dataloader_idx=1
        tensor([10, 11, 12, 13, 14]), batch_idx=2, dataloader_idx=1

    """

    def __init__(self, iterables: Any, mode: _LITERAL_SUPPORTED_MODES = "min_size") -> None:
        if mode not in _SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode {mode!r}, please select one of: {list(_SUPPORTED_MODES)}.")
        self._iterables = iterables
        self._flattened, self._spec = _tree_flatten(iterables)
        self._mode = mode
        self._iterator: Optional[_ModeIterator] = None
        self._limits: Optional[List[Union[int, float]]] = None

    @property
    def iterables(self) -> Any:
        """Return the original collection of iterables."""
        return self._iterables

    @property
    def sampler(self) -> Any:
        """Return a collection of samplers extracted from the iterables."""
        return _map_and_unflatten(lambda x: getattr(x, "sampler", None), self.flattened, self._spec)

    @property
    def batch_sampler(self) -> Any:
        """Return a collection of batch samplers extracted from the iterables."""
        return _map_and_unflatten(lambda x: getattr(x, "batch_sampler", None), self.flattened, self._spec)

    @property
    def flattened(self) -> List[Any]:
        """Return the flat list of iterables."""
        return self._flattened

    @flattened.setter
    def flattened(self, flattened: List[Any]) -> None:
        """Setter to conveniently update the list of iterables."""
        if len(flattened) != len(self._flattened):
            raise ValueError(
                f"Mismatch in flattened length ({len(flattened)}) and existing length ({len(self._flattened)})"
            )
        # update the iterable collection
        self._iterables = tree_unflatten(flattened, self._spec)
        self._flattened = flattened

    @property
    def limits(self) -> Optional[List[Union[int, float]]]:
        """Optional limits per iterator."""
        return self._limits

    @limits.setter
    def limits(self, limits: Optional[Union[int, float, List[Union[int, float]]]]) -> None:
        if isinstance(limits, (int, float)):
            limits = [limits] * len(self.flattened)
        elif isinstance(limits, list) and len(limits) != len(self.flattened):
            raise ValueError(
                f"Mismatch in number of limits ({len(limits)}) and number of iterables ({len(self.flattened)})"
            )
        self._limits = limits

    def __next__(self) -> _ITERATOR_RETURN:
        assert self._iterator is not None
        out = next(self._iterator)
        if isinstance(self._iterator, _Sequential):
            return out
        out, batch_idx, dataloader_idx = out
        return tree_unflatten(out, self._spec), batch_idx, dataloader_idx

    @override
    def __iter__(self) -> Self:
        cls = _SUPPORTED_MODES[self._mode]["iterator"]
        iterator = cls(self.flattened, self._limits)
        iter(iterator)
        self._iterator = iterator
        return self

    def __len__(self) -> int:
        """Compute the number of batches."""
        if self._iterator is None:
            raise RuntimeError("Please call `iter(combined_loader)` first.")
        return len(self._iterator)

    def reset(self) -> None:
        """Reset the state and shut down any workers."""
        if self._iterator is not None:
            self._iterator.reset()
            self._iterator = None
        for iterable in self.flattened:
            _shutdown_workers_and_reset_iterator(iterable)

    def _dataset_length(self) -> int:
        """Compute the total length of the datasets according to the current mode."""
        datasets = [getattr(dl, "dataset", None) for dl in self.flattened]
        lengths = [length for ds in datasets if (length := sized_len(ds)) is not None]
        if not lengths:
            raise NotImplementedError("All datasets are iterable-style datasets.")
        fn = _SUPPORTED_MODES[self._mode]["fn"]
        return fn(lengths)

    def _state_dicts(self) -> List[Dict[str, Any]]:
        """Returns the list of state dicts for iterables in `self.flattened` that are stateful."""
        return [loader.state_dict() for loader in self.flattened if isinstance(loader, _Stateful)]

    def _load_state_dicts(self, states: List[Dict[str, Any]]) -> None:
        """Loads the state dicts for iterables in `self.flattened` that are stateful."""
        if not states:
            return
        stateful_loaders = [loader for loader in self.flattened if isinstance(loader, _Stateful)]
        if len(stateful_loaders) != len(states):
            raise RuntimeError(
                f"The CombinedLoader has {len(stateful_loaders)} stateful loaders, but found {len(states)} states"
                " in the checkpoint. Please make sure you define the same dataloaders that were used when saving"
                " the checkpoint."
            )
        for loader, state_dict in zip(stateful_loaders, states):
            loader.load_state_dict(state_dict)
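
# A minimal sketch of the `limits` setter (illustrative only; plain ranges stand in for
# dataloaders): a scalar limit is broadcast to every iterable and caps the number of batches
# drawn from each of them.
#
#     >>> loader = CombinedLoader({"a": range(4), "b": range(10)}, "sequential")
#     >>> loader.limits = 2
#     >>> _ = iter(loader)
#     >>> len(loader)
#     4
#     >>> list(loader)
#     [(0, 0, 0), (1, 1, 0), (0, 0, 1), (1, 1, 1)]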


def _shutdown_workers_and_reset_iterator(dataloader: object) -> None:
    if hasattr(dataloader, "_iterator"):
        if isinstance(dataloader._iterator, _MultiProcessingDataLoaderIter):
            dataloader._iterator._shutdown_workers()
        dataloader._iterator = None


def _get_iterables_lengths(iterables: List[Iterable]) -> List[Union[int, float]]:
    return [(float("inf") if (length := sized_len(iterable)) is None else length) for iterable in iterables]
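
# Illustrative only: `_get_iterables_lengths` falls back to infinity for iterables whose length
# cannot be determined, such as a generator.
#
#     >>> _get_iterables_lengths([range(3), (x for x in range(3))])
#     [3, inf]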