Fault-tolerant Training

Warning

Fault-tolerant Training is currently an experimental feature within Lightning.

Fault-tolerant Training is an internal mechanism that enables PyTorch Lightning to recover from a hardware or software failure. This is particularly useful when training in the cloud on preemptible instances, which can shut down at any time.

Until now, if Trainer.fit() failed in the middle of an epoch during training or validation, the user had to restart that epoch from the beginning, losing any progress made during it. This also made benchmarking non-reproducible, because the optimization was interrupted and only partially restored.

With Fault-tolerant Training, when Trainer.fit() fails in the middle of an epoch during training or validation, Lightning will restart exactly where it failed and restore the tracked state.

Fault tolerance can be enabled as follows:

PL_FAULT_TOLERANT_TRAINING=1 python script.py
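When the launch command cannot be changed, the same variable can presumably be set from inside the Python entry point. The following is a minimal sketch, assuming Lightning reads PL_FAULT_TOLERANT_TRAINING from the process environment no earlier than its first import. That timing is an assumption and is not confirmed by this page.

```python
import os

# Assumption: the variable is read from the process environment when
# Lightning is imported or when the Trainer starts, so set it first.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

# Import Lightning only after the variable is set, for example:
# import pytorch_lightning as pl
```

Setting the variable on the command line, as shown above, remains the documented way to enable the feature.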

Under The Hood

Lightning keeps track of the following state updates during training:

  • Sampler indices and random states across multiple processes and workers: this enables restoring random transforms and batch fetching to the exact state they were in right before the failure.

  • Optimizers, learning rate schedulers, callbacks, etc.

  • Loop progression.

  • Internal logging states, so that metric reductions at the end of an epoch are not affected by the failure and model selection can continue as expected.
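The first item in the list can be illustrated with plain PyTorch. The sketch below is not Lightning's internal implementation. It only shows the idea of capturing a random state together with an iteration count, then using both to fast-forward to the sample where a run stopped. The helper name remaining_indices is hypothetical.

```python
import torch


def remaining_indices(num_samples, generator_state, consumed):
    """Rebuild the epoch's shuffled order and skip what was already consumed."""
    generator = torch.Generator()
    generator.set_state(generator_state)
    order = torch.randperm(num_samples, generator=generator).tolist()
    return order[consumed:]


generator = torch.Generator()
generator.manual_seed(42)

# capture the random state before the epoch's permutation is drawn
state = generator.get_state()
full_order = torch.randperm(10, generator=generator).tolist()

# pretend the run failed after 4 samples were fetched
consumed = 4
resumed = remaining_indices(10, state, consumed)
```

After the simulated restart, resumed holds the same indices that the uninterrupted run would have fetched next.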

Currently Supported

If you are using a single map-based dataset by subclassing Dataset, everything should work as expected.

import torch
from torch.utils.data import Dataset, DataLoader


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len
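As a usage sketch, a map-based dataset like this one can be passed to a standard DataLoader. The class is repeated so that the snippet runs on its own. The sizes and the batch size are arbitrary values chosen for illustration.

```python
import torch
from torch.utils.data import Dataset, DataLoader


class RandomDataset(Dataset):
    def __init__(self, size: int, length: int):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


# 64 samples of 32 features each, served in batches of 8
loader = DataLoader(RandomDataset(size=32, length=64), batch_size=8)
batches = list(loader)
```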

If you are using a single iterable-based dataset, there are some limitations. To support fault-tolerance, you will need to use and expose a sampler within your dataset.

For example, the following iterable dataset, which subclasses IterableDataset, is not supported.

import torch
from torch.utils.data import IterableDataset, DataLoader


# does not support fault-tolerant training!
class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, count: int):
        self.count = count
        self.size = size

    def __iter__(self):
        for _ in range(self.count):
            yield torch.randn(self.size)

There are two primary reasons why Lightning can’t support the previous implementation.

  • Lightning cannot infer what you are iterating over, which makes it difficult to restart training. Fault-tolerant Training requires a Sampler to encapsulate the fetching logic. Both the sampler and its iterator must be exposed as attributes of the dataset so that Lightning can access them to track progress.

  • Implementing the __next__ method is required because it separates the creation of the iterator from its consumption, which lets Lightning wrap the iterator before it is consumed.

If your iterable dataset is implemented in the following way, everything should work as expected.

import torch
from torch.utils.data import IterableDataset, DataLoader, RandomSampler


class RandomIterableDataset(IterableDataset):
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)

        # expose the sampler as an attribute
        self.sampler = RandomSampler(range(length))

    def __iter__(self) -> "RandomIterableDataset":
        # expose the iterator created from the sampler as an attribute.
        # Lightning wraps ``sampler_iter`` to capture random seeds and the
        # iteration count, so the sampler can be fast-forwarded on restart.
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self) -> torch.Tensor:
        # call next on the iterator and get the associated data.
        # the logic here can become more complex but the sampler
        # should be the central piece for fetching the next sample
        index = next(self.sampler_iter)
        return self.data[index]
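To illustrate why separating __iter__ from __next__ matters, the sketch below wraps the exposed sampler_iter after the iterator has been created but before any sample is fetched. CountingIterator and SequentialIterableDataset are hypothetical names used only for this illustration. This is not Lightning's actual wrapper, and a SequentialSampler replaces the RandomSampler to keep the result deterministic.

```python
import torch
from torch.utils.data import IterableDataset, SequentialSampler


class CountingIterator:
    """Hypothetical wrapper that records how many indices were consumed."""

    def __init__(self, iterator):
        self._iterator = iterator
        self.consumed = 0

    def __iter__(self):
        return self

    def __next__(self):
        index = next(self._iterator)
        self.consumed += 1
        return index


class SequentialIterableDataset(IterableDataset):
    def __init__(self, size: int, length: int):
        self.data = torch.randn(length, size)
        self.sampler = SequentialSampler(range(length))

    def __iter__(self):
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        return self.data[next(self.sampler_iter)]


dataset = SequentialIterableDataset(size=4, length=10)
iterator = iter(dataset)  # creates ``sampler_iter`` but consumes nothing yet

# wrap the exposed iterator before any sample is fetched
dataset.sampler_iter = CountingIterator(dataset.sampler_iter)

samples = [next(iterator) for _ in range(3)]
```

Because every fetch goes through sampler_iter, the wrapper sees each index and can record the progress needed for a restart. With a generator-based __iter__, there is no such hook between creating the iterator and consuming it.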

Current Known Limitations

If you are using multiple training dataloaders, Lightning won’t be able to restore the random state properly.

import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):
    def train_dataloader(self):
        loader_a = torch.utils.data.DataLoader(range(8), batch_size=4)
        loader_b = torch.utils.data.DataLoader(range(16), batch_size=4)
        return {"loader_a": loader_a, "loader_b": loader_b}

    def training_step(self, batch, batch_idx):
        # access the data in the same format as the collection of dataloaders.
        # dict, list are supported.
        loader_a = batch["loader_a"]
        loader_b = batch["loader_b"]

If you believe this to be useful, please open a feature request.

Performance Impacts

Fault-tolerant Training was tested on common and worst-case scenarios to measure the impact of the internal state tracking on total training time. On tiny models such as the BoringModel with a RandomDataset, which have virtually no data loading or processing overhead, we noticed up to 50% longer training time with fault tolerance enabled. In this worst-case scenario, fault tolerance adds an overhead that is noticeable in comparison to the compute time for data loading itself. However, for more realistic training workloads where data loading and preprocessing are more expensive, the constant overhead that fault tolerance adds becomes less noticeable or not noticeable at all. For example, when training ResNet50 on CIFAR-10, we observed a 0.5% to 1% increase in training time, depending on batch size or number of workers.

More detailed benchmarks will be shared in the future.
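The effect of a constant overhead can be made concrete with a small calculation. The timings below are hypothetical values chosen only to reproduce ratios of the same order as those quoted above. They are not measurements.

```python
def relative_overhead(step_time_ms: float, overhead_ms: float) -> float:
    """Fraction of extra time added by a fixed per-step overhead."""
    return overhead_ms / step_time_ms


# hypothetical fixed cost of state tracking per step
overhead_ms = 1.0

# hypothetical tiny model: 2 ms per step, so the overhead is 50%
tiny_model = relative_overhead(2.0, overhead_ms)

# hypothetical larger model: 200 ms per step, so the overhead is 0.5%
large_model = relative_overhead(200.0, overhead_ms)
```

The same fixed cost shrinks from half of the step time to a fraction of a percent as the per-step work grows.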

Note

The extra time is coming from several parts:

  • Capturing the iteration count and random states for each sample within each DataLoader worker and passing them through the data_queue.

  • Extra logic to handle and store the dataloader’s state for each batch.