LightningDataModule

A datamodule is a shareable, reusable class that encapsulates the five steps involved in data processing in PyTorch:

  1. Download / tokenize / process.

  2. Clean and (maybe) save to disk.

  3. Load inside Dataset.

  4. Apply transforms (rotate, tokenize, etc…).

  5. Wrap inside a DataLoader.


This class can then be shared and used anywhere:

model = LitClassifier()
trainer = Trainer()

imagenet = ImagenetDataModule()
trainer.fit(model, datamodule=imagenet)

cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)

Why do I need a DataModule?

In plain PyTorch code, data cleaning and preparation are usually scattered across many files, which makes it hard to share and reuse the exact splits and transforms across projects.

Datamodules are for you if you have ever asked questions like:

  • what splits did you use?

  • what transforms did you use?

  • what normalization did you use?

  • how did you prepare/tokenize the data?


What is a DataModule?

The LightningDataModule is a convenient way to manage data in PyTorch Lightning. It encapsulates training, validation, testing, and prediction dataloaders, as well as any necessary steps for data processing, downloads, and transformations. By using a LightningDataModule, you can easily develop dataset-agnostic models, hot-swap different datasets, and share data splits and transformations across projects.

Here’s a simple PyTorch example:

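# regular PyTorch
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

my_path = "path/to/dir"  # where MNIST is stored
to_tensor = transforms.ToTensor()  # PIL images -> tensors so the default collate works

test_data = MNIST(my_path, train=False, download=True, transform=to_tensor)
predict_data = MNIST(my_path, train=False, download=True, transform=to_tensor)
train_data = MNIST(my_path, train=True, download=True, transform=to_tensor)
train_data, val_data = random_split(train_data, [55000, 5000], generator=torch.Generator().manual_seed(42))

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)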

The equivalent DataModule organizes exactly the same code but makes it reusable across projects.

import lightning as L
import torch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.transform = transforms.ToTensor()

    def setup(self, stage: str):
        # assumes MNIST has already been downloaded to data_dir
        self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
        self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)
        mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean up when the run is finished
        ...

As your processing grows more complex (transforms, multi-GPU training), you can let Lightning handle those details for you, while the dataset stays reusable so you can share it with colleagues or use it in different projects.

mnist = MNISTDataModule(my_path)
model = LitClassifier()

trainer = Trainer()
trainer.fit(model, mnist)
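The MNISTDataModule above passes generator=torch.Generator().manual_seed(42) to random_split, which is what answers "what splits did you use?": with the same seed (and the same PyTorch version), the split is identical on every run. A minimal sketch, assuming a plain list of ten items as a stand-in for the dataset; make_split is a hypothetical helper:

```python
import torch
from torch.utils.data import random_split

data = list(range(10))  # stand-in for a real dataset


def make_split(seed: int):
    # Same seed -> same permutation -> same train/val indices.
    train, val = random_split(data, [8, 2], generator=torch.Generator().manual_seed(seed))
    return sorted(train.indices), sorted(val.indices)


first = make_split(42)
second = make_split(42)
print(first == second)  # True
```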

Here’s a more realistic DataModule that adds downloading, normalization, and stage-specific setup, which makes it even more reusable.

import lightning as L
import torch
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)

LightningDataModule API

To define a DataModule, implement the following methods to create the train/val/test/predict dataloaders:

prepare_data

Downloading and saving data with multiple processes (distributed settings) can corrupt the data. Lightning ensures that prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic there. In multi-node training, where this hook runs depends on prepare_data_per_node. setup() is called after prepare_data(), with a barrier in between that ensures all processes proceed to setup() only once the data is prepared and available for use.

Typical uses of prepare_data():

  • download, i.e. download data to disk only once, from a single process

  • tokenize, since it’s a one-time process that should not be repeated on every process

  • etc…

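import os

import lightning as L
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download only; do not assign state here
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())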

Warning

prepare_data is called from the main process. Avoid assigning state here (e.g. self.x = y): because it runs on a single process, any state you assign will not be available to the other processes.

setup

There are also data operations you might want to perform on every GPU. Use setup() to do things like:

  • count number of classes

  • build vocabulary

  • perform train/val/test splits

  • create datasets

  • apply transforms (defined explicitly in your datamodule)

  • etc…

import lightning as L
import torch
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    # assumes self.data_dir and self.transform are set in __init__,
    # and that prepare_data() has already downloaded MNIST
    def setup(self, stage: str):
        # Assign train/val split(s) for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test split(s) for use in dataloaders
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

For example, if you are working on an NLP task where you need to tokenize text before using it, you can do something like the following:

import lightning as L


class LitDataModule(L.LightningDataModule):
    # load_dataset and load_dataset_from_disk are placeholders for your own loading code
    def prepare_data(self):
        dataset = load_dataset(...)
        train_dataset = ...
        val_dataset = ...
        # tokenize
        # save it to disk

    def setup(self, stage: str):
        # load it back here
        dataset = load_dataset_from_disk(...)

setup() expects a stage argument, which is used to separate setup logic for trainer.{fit,validate,test,predict}.

Note

setup is called from every process across all the nodes. Setting state here is recommended.

Note

teardown can be used to clean up the state. It is also called from every process across all the nodes.

train_dataloader

Use the train_dataloader() method to generate the training dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() method uses.

import lightning as L
from torch.utils.data import DataLoader


class MNISTDataModule(L.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

val_dataloader

Use the val_dataloader() method to generate the validation dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() and validate() methods use.

import lightning as L
from torch.utils.data import DataLoader


class MNISTDataModule(L.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

test_dataloader

Use the test_dataloader() method to generate the test dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer test() method uses.

import lightning as L
from torch.utils.data import DataLoader


class MNISTDataModule(L.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

predict_dataloader

Use the predict_dataloader() method to generate the prediction dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer predict() method uses.

import lightning as L
from torch.utils.data import DataLoader


class MNISTDataModule(L.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=64)

transfer_batch_to_device

LightningDataModule.transfer_batch_to_device(batch, device, dataloader_idx)

Override this hook if your DataLoader returns tensors wrapped in a custom data structure.

The data types listed below (and any arbitrary nesting of them) are supported out of the box:

  • torch.Tensor or anything that implements .to(...)

  • list

  • dict

  • tuple

For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).

Note

This hook should only transfer the data, not modify it, and it should not move the data to any device other than the one passed in as an argument (unless you know what you are doing). To check the current execution state inside this hook, use self.trainer.training/testing/validating/predicting, so you can add different logic as needed.

Parameters:
  • batch (Any) – A batch of data that needs to be transferred to a new device.

  • device (device) – The target device as defined in PyTorch.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type:

Any

Returns:

A reference to the data on the new device.

Example:

def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch

See also

  • move_data_to_device()

  • apply_to_collection()

on_before_batch_transfer

LightningDataModule.on_before_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch before it is transferred to the device.

Note

To check the current execution state inside this hook, use self.trainer.training/testing/validating/predicting, so you can add different logic as needed.

Parameters:
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type:

Any

Returns:

A batch of data

Example:

def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch

See also

  • on_after_batch_transfer()

  • transfer_batch_to_device()

on_after_batch_transfer

LightningDataModule.on_after_batch_transfer(batch, dataloader_idx)

Override to alter or apply batch augmentations to your batch after it is transferred to the device.

Note

To check the current execution state inside this hook, use self.trainer.training/testing/validating/predicting, so you can add different logic as needed.

Parameters:
  • batch (Any) – A batch of data that needs to be altered or augmented.

  • dataloader_idx (int) – The index of the dataloader to which the batch belongs.

Return type:

Any

Returns:

A batch of data

Example:

def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch

See also

  • on_before_batch_transfer()

  • transfer_batch_to_device()

load_state_dict

LightningDataModule.load_state_dict(state_dict)

Called when loading a checkpoint. Implement it to reload the datamodule state from the given state_dict.

Parameters:

state_dict (dict[str, Any]) – the datamodule state returned by state_dict.

Return type:

None

state_dict

LightningDataModule.state_dict()

Called when saving a checkpoint. Implement it to generate and save the datamodule state.

Return type:

dict[str, Any]

Returns:

A dictionary containing datamodule state.

teardown

LightningDataModule.teardown(stage)

Called at the end of fit (train + validate), validate, test, or predict.

Parameters:

stage (str) – either 'fit', 'validate', 'test', or 'predict'

Return type:

None

prepare_data_per_node

If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is called only on NODE_RANK=0, LOCAL_RANK=0.

import lightning as L


class LitDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
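The rule above can be written out as a small predicate to see which processes run prepare_data(). calls_prepare_data is a hypothetical helper that mirrors the documented behavior; it is not a Lightning API. As a design note, and as an assumption about typical setups: True suits nodes with their own local disks (each node needs its own copy), while False suits a filesystem shared by all nodes.

```python
def calls_prepare_data(node_rank: int, local_rank: int, prepare_data_per_node: bool) -> bool:
    """Hypothetical helper mirroring the documented rule (not a Lightning API)."""
    if local_rank != 0:
        return False  # never on non-zero local ranks
    return prepare_data_per_node or node_rank == 0


# 2 nodes x 2 processes per node, as (node_rank, local_rank) pairs
ranks = [(n, l) for n in range(2) for l in range(2)]
print([r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=True)])   # [(0, 0), (1, 0)]
print([r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=False)])  # [(0, 0)]
```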

Using a DataModule

The recommended way to use a DataModule is simply:

import lightning as L

dm = MNISTDataModule()
model = Model()
trainer = L.Trainer()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)

If you need information from the dataset to build your model, then run prepare_data and setup manually (Lightning ensures the method runs on the correct devices).

dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

You can access the datamodule currently in use via trainer.datamodule, and the dataloaders currently in use via the trainer properties train_dataloader, val_dataloaders, test_dataloaders, and predict_dataloaders.


DataModules without Lightning

You can of course use DataModules in plain PyTorch code as well.

# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")

Overall, DataModules encourage reproducibility by letting you specify every detail of a dataset in a single, unified structure.


Hyperparameters in DataModules

Like LightningModules, DataModules support hyperparameters with the same API.

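import lightning as L
from torch.utils.data import DataLoader


class CustomDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        # store the __init__ arguments in self.hparams
        self.save_hyperparameters()

    def train_dataloader(self):
        # access the saved hyperparameters
        # (self.mnist_train is assumed to be created in setup())
        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size)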

Refer to save_hyperparameters in the LightningModule documentation for more details.


Save DataModule state

When a checkpoint is created, Lightning asks every DataModule for its state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore your DataModule.

import lightning as L


class LitDataModule(L.LightningDataModule):
    def __init__(self):
        super().__init__()
        # example state to track across checkpoints
        self.current_train_batch_index = 0

    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]