
Fault-tolerant Training (FAQ)

How do I use iterable datasets?

To support fault tolerance, your dataset must fetch samples through a sampler and expose that sampler as an attribute.

For example, the following iterable dataset, which subclasses IterableDataset, is not supported.

import torch
from torch.utils.data import IterableDataset, DataLoader

# does not support fault-tolerant training!
class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)

There are two primary reasons why Lightning can’t support the previous implementation.

  • Lightning cannot infer what you are iterating over, which makes it difficult to restart training. Fault-tolerant training requires a Sampler that encapsulates the fetching logic, and both the sampler and its iterator must be exposed as attributes of the dataset so that Lightning can access them and track progress.

  • You must implement the __next__ method, because it separates the creation of the iterator from its consumption. This separation is essential: Lightning wraps the iterator after it is created and before it is consumed.
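To illustrate why separating creation from consumption matters, here is a minimal sketch in plain Python (no torch, no Lightning). Because `__iter__` only creates `sampler_iter` and `__next__` only consumes it, outside code can replace `sampler_iter` with a wrapper in between. `CountingIterator` and `IndexDataset` are hypothetical names for this illustration; this is not how Lightning implements its wrapping.

```python
import random


class CountingIterator:
    """Hypothetical wrapper that counts how many indices have been consumed."""

    def __init__(self, inner):
        self.inner = inner
        self.consumed = 0

    def __iter__(self):
        return self

    def __next__(self):
        index = next(self.inner)  # StopIteration from the inner iterator propagates
        self.consumed += 1
        return index


class IndexDataset:
    """Hypothetical dataset following the sampler + ``__next__`` pattern, without torch."""

    def __init__(self, data, seed=0):
        self.data = data
        order = list(range(len(data)))
        random.Random(seed).shuffle(order)
        self.sampler = order  # a shuffled index list standing in for a torch Sampler

    def __iter__(self):
        # creation only: a fresh index iterator is exposed, nothing is consumed yet
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        # consumption: always goes through whatever ``sampler_iter`` currently is
        return self.data[next(self.sampler_iter)]


dataset = IndexDataset(["a", "b", "c", "d"])
it = iter(dataset)
# between creation and consumption, outside code can swap in a wrapper
dataset.sampler_iter = CountingIterator(dataset.sampler_iter)
first_two = [next(it), next(it)]
print(dataset.sampler_iter.consumed)  # 2
```

With a generator-based __iter__, as in the unsupported example, the loop state lives inside the generator frame, so there is no attribute that outside code could wrap.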

If your iterable dataset is implemented as follows, everything should work as expected.

import torch
from torch.utils.data import DataLoader, IterableDataset, RandomSampler

class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

        # expose the sampler as an attribute
        self.sampler = RandomSampler(range(length))

    def __iter__(self) -> "RandomIterableDataset":
        # expose the generator from the sampler as an attribute
        # the ``sampler_iter`` will be wrapped by Lightning to ensure
        # we can capture random seeds and iteration count for fast-forward samplers
        # while restarting.
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self) -> torch.Tensor:
        # call next on the iterator and get the associated data.
        # the logic here can become more complex but the sampler
        # should be the central piece for fetching the next sample
        index = next(self.sampler_iter)
        return self.data[index]
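The comments in `__iter__` mention capturing random seeds and iteration counts so that samplers can be fast-forwarded on restart. The idea can be sketched in plain Python, assuming the shuffled order is fully determined by a seed: if the seed and the number of indices already consumed are saved, a restarted process can rebuild the same order and skip what was already consumed. `ResumableShuffleSampler` and its `state_dict`/`load_state_dict` methods are hypothetical names for this illustration, not Lightning's internal implementation.

```python
import random


class ResumableShuffleSampler:
    """Hypothetical sampler whose position can be saved and restored."""

    def __init__(self, length, seed):
        self.length = length
        self.seed = seed
        self.consumed = 0  # indices yielded so far in the current pass

    def _order(self):
        # the permutation depends only on the seed, so it can be rebuilt after a crash
        order = list(range(self.length))
        random.Random(self.seed).shuffle(order)
        return order

    def __iter__(self):
        # fast-forward: skip the indices consumed before the failure
        for index in self._order()[self.consumed:]:
            self.consumed += 1
            yield index
        self.consumed = 0  # pass finished; the next pass starts from the beginning

    def state_dict(self):
        return {"seed": self.seed, "consumed": self.consumed}

    def load_state_dict(self, state):
        self.seed = state["seed"]
        self.consumed = state["consumed"]


# uninterrupted run, used as the reference order
reference = list(ResumableShuffleSampler(10, seed=42))

# interrupted run: consume 4 indices, then save the state as if a failure occurred
sampler = ResumableShuffleSampler(10, seed=42)
it = iter(sampler)
before_failure = [next(it) for _ in range(4)]
checkpoint = sampler.state_dict()

# restart: a fresh sampler restored from the checkpoint continues where it stopped
restored = ResumableShuffleSampler(10, seed=0)
restored.load_state_dict(checkpoint)
after_restart = list(restored)

assert before_failure + after_restart == reference
```

Saving a seed plus a counter keeps the checkpoint small; the cost is that the permutation must be recomputed on restart, which is cheap compared with re-reading the data.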

How do I use multiple dataloaders?

If you are using multiple training dataloaders, Lightning won’t be able to restore the random state properly.

import torch
from pytorch_lightning import LightningModule

class LitModel(LightningModule):
    def train_dataloader(self):
        loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
        loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
        return {"loader_a": loader_a, "loader_b": loader_b}

    def training_step(self, batch, batch_idx):
        # the batch has the same structure as the returned collection of dataloaders;
        # dicts and lists are supported.
        loader_a = batch["loader_a"]
        loader_b = batch["loader_b"]

If you believe this to be useful, please open a feature request.

What are the performance impacts?

Fault-tolerant training was tested on common and worst-case scenarios to measure how much the internal state tracking adds to the total training time. On tiny models such as the BoringModel with a RandomDataset, which have virtually no data loading or processing cost, training took up to 50% longer with fault tolerance enabled. In this worst case, the overhead is significant relative to the time spent loading the data itself. For more realistic workloads, where data loading and preprocessing are more expensive, the constant overhead added by fault tolerance becomes small or negligible. For example, when training a ResNet50 on CIFAR-10, we observed a 0.5% to 1% increase in training time, depending on the batch size and the number of workers.

More detailed benchmarks will be shared in the future.


The extra time comes from two sources:

  • Capturing the iteration count and random states for each sample within each DataLoader worker, and passing them through the data_queue.

  • Extra logic to handle and store the DataLoader states from each batch.
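To make the first source of overhead concrete, here is a simplified single-process sketch in which threads stand in for DataLoader worker processes and a `queue.Queue` stands in for the internal data_queue. Every sample travels with its worker's iteration count and a snapshot of its random state, and the consumer keeps the latest state per worker. `worker_loop` and the tuple layout are hypothetical; the point is only that this bookkeeping runs once per sample regardless of how expensive the sample is to produce, which is why it stands out when loading is nearly free.

```python
import queue
import random
import threading


def worker_loop(worker_id, num_samples, data_queue):
    """Hypothetical worker: sends each sample together with its tracking state."""
    rng = random.Random(worker_id)  # each worker owns its own RNG
    for count in range(1, num_samples + 1):
        sample = rng.random()  # stands in for loading a sample + a random transform
        # the extra per-sample work: snapshot the iteration count and RNG state
        state = {"count": count, "rng_state": rng.getstate()}
        data_queue.put((worker_id, sample, state))


data_queue = queue.Queue()
workers = [
    threading.Thread(target=worker_loop, args=(worker_id, 5, data_queue))
    for worker_id in range(2)
]
for worker in workers:
    worker.start()
for worker in workers:
    worker.join()

# main-process side: keep the most recent state reported by each worker
samples, latest_state = [], {}
while not data_queue.empty():
    worker_id, sample, state = data_queue.get()
    samples.append(sample)
    latest_state[worker_id] = state

print(sorted((wid, s["count"]) for wid, s in latest_state.items()))  # [(0, 5), (1, 5)]

# on restart, a worker's RNG restored from the saved state draws the same next value
resumed = random.Random()
resumed.setstate(latest_state[0]["rng_state"])
reference = random.Random(0)
for _ in range(5):
    reference.random()
assert resumed.random() == reference.random()
```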

What happens to my shuffled dataset?

If you are using a single map-style dataset that subclasses Dataset, everything should work as expected.

import torch
from torch.utils.data import Dataset, DataLoader

class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len

What parts are fault-tolerant?

Lightning keeps track of the following state updates during training:

  • Sampler indices and random states across multiple processes and workers: this makes it possible to restore random transforms and batch fetching to the exact state they were in right before the failure.

  • Optimizers, learning rate schedulers, callbacks, etc.

  • Loop progression

  • Logging internal states, so that epoch-end metric reductions are not affected by the failure and model selection continues as expected.

© Copyright 2018-2023, Lightning AI et al.
