
LightningDataModule

A datamodule is a shareable, reusable class that encapsulates the five steps involved in data processing in PyTorch:

  1. Download / tokenize / process.

  2. Clean and (maybe) save to disk.

  3. Load inside Dataset.

  4. Apply transforms (rotate, tokenize, etc…).

  5. Wrap inside a DataLoader.
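To make the five steps concrete, here is a library-free toy sketch that maps each step onto a method. It uses only the Python standard library. The class `ToyTextData`, its inline "raw" data and the batching generator are hypothetical stand-ins for a real download, a `Dataset` and a `DataLoader`. This is not Lightning's API.

```python
import json
import tempfile
from pathlib import Path

# Hypothetical raw data standing in for a real download.
RAW_LINES = ["  Hello World ", "", "lightning DataModules ", "  Reuse  Me  "]


class ToyTextData:
    """Library-free illustration of the five data-processing steps."""

    def __init__(self, data_dir, batch_size=2):
        self.data_dir = Path(data_dir)
        self.batch_size = batch_size

    def prepare_data(self):
        # Steps 1 and 2: "download", clean, and save to disk.
        cleaned = [" ".join(line.split()).lower() for line in RAW_LINES]
        cleaned = [line for line in cleaned if line]  # drop empty lines
        (self.data_dir / "clean.json").write_text(json.dumps(cleaned))

    def setup(self):
        # Step 3: load the cleaned data into an in-memory "dataset".
        self.dataset = json.loads((self.data_dir / "clean.json").read_text())

    def transform(self, line):
        # Step 4: apply a transform (here, whitespace tokenization).
        return line.split()

    def train_dataloader(self):
        # Step 5: wrap the dataset in a simple batching iterator.
        items = [self.transform(line) for line in self.dataset]
        for start in range(0, len(items), self.batch_size):
            yield items[start:start + self.batch_size]


with tempfile.TemporaryDirectory() as tmp:
    data = ToyTextData(tmp)
    data.prepare_data()
    data.setup()
    batches = list(data.train_dataloader())

print(batches)
```

Keeping each step in its own method is what lets the same splits and transforms be reused elsewhere.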


This class can then be shared and used anywhere:

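from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule
from pytorch_lightning import Trainer

model = LitClassifier()  # your own LightningModule
trainer = Trainer()

# ImageNet cannot be downloaded automatically, so point data_dir at a local copy
imagenet = ImagenetDataModule(data_dir="path/to/imagenet")
trainer.fit(model, imagenet)

cifar10 = CIFAR10DataModule()
trainer.fit(model, cifar10)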

Why do I need a DataModule?

In normal PyTorch code, the data cleaning/preparation is usually scattered across many files. This makes sharing and reusing the exact splits and transforms across projects very difficult.

Datamodules are for you if you have ever asked questions like:

  • what splits did you use?

  • what transforms did you use?

  • what normalization did you use?

  • how did you prepare/tokenize the data?


What is a DataModule

A DataModule is simply a collection of train_dataloader(s), val_dataloader(s) and test_dataloader(s), along with the matching transforms and the data processing/download steps they require.

Here’s a simple PyTorch example:

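# regular PyTorch
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST

my_path = "path/to/dir"

# ToTensor() converts PIL images to tensors so DataLoader can batch them
test_data = MNIST(my_path, train=False, download=True, transform=transforms.ToTensor())
train_data = MNIST(my_path, train=True, download=True, transform=transforms.ToTensor())
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)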
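The lengths passed to random_split must add up to the dataset size (the MNIST training set has 60,000 images, hence 55,000 + 5,000). To answer "what splits did you use?", the split also has to be reproducible; torch's random_split accepts a `generator` argument (for example `torch.Generator().manual_seed(42)`) for that purpose. The sketch below shows the same idea with only the standard library. `seeded_split` is a hypothetical helper, not torch's algorithm, so it will not produce the same indices as random_split.

```python
import random


def seeded_split(items, lengths, seed):
    """Shuffle indices with a fixed seed, then cut them into consecutive chunks."""
    if sum(lengths) != len(items):
        raise ValueError("lengths must sum to the number of items")
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)
    splits, start = [], 0
    for length in lengths:
        splits.append([items[i] for i in indices[start:start + length]])
        start += length
    return splits


# Stand-in for the 60,000 MNIST training examples: 55,000 + 5,000 = 60,000.
samples = list(range(60_000))
train_a, val_a = seeded_split(samples, [55_000, 5_000], seed=42)
train_b, val_b = seeded_split(samples, [55_000, 5_000], seed=42)

print(len(train_a), len(val_a))  # 55000 5000
print(val_a == val_b)            # True: same seed, same split
```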

The equivalent DataModule just organizes the same exact code, but makes it reusable across projects.

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        # Assumes MNIST has already been downloaded to data_dir
        self.mnist_test = MNIST(self.data_dir, train=False, transform=transforms.ToTensor())
        mnist_full = MNIST(self.data_dir, train=True, transform=transforms.ToTensor())
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None):
        # Used to clean up when the run is finished
        ...

As the complexity of your processing grows (transforms, multi-GPU training), you can let Lightning handle those details for you, while keeping the dataset reusable so you can share it with colleagues or use it in different projects.

mnist = MNISTDataModule(my_path)
model = LitClassifier()

trainer = Trainer()
trainer.fit(model, mnist)

Here’s a more realistic, complex DataModule that shows how much more reusable the datamodule is.

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):

        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

            # Optionally...
            # self.dims = tuple(self.mnist_train[0][0].shape)

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

            # Optionally...
            # self.dims = tuple(self.mnist_test[0][0].shape)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

LightningDataModule API

To define a DataModule, implement the following five methods:

  • prepare_data (how to download, tokenize, etc…)

  • setup (how to split, etc…)

  • train_dataloader

  • val_dataloader(s)

  • test_dataloader(s)

and optionally one or multiple predict_dataloader(s).

prepare_data

Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.

  • download

  • tokenize

  • etc…

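import os

import pytorch_lightning as pl
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download only; do not assign state here
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())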

Warning

prepare_data is called from a single process (e.g. GPU 0), so state assigned there is not visible to the other processes. Do not use it to assign state (self.x = y).
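The warning is easier to see with a toy, single-process simulation. It gives each simulated "rank" its own copy of a hypothetical `ToyDataModule`, mimicking separate processes, and calls `prepare_data` only on rank 0. Real distributed training runs genuinely separate processes and Lightning decides which one calls prepare_data; this sketch only reproduces the effect.

```python
class ToyDataModule:
    def prepare_data(self):
        # Anti-pattern: assigning state in prepare_data.
        self.vocab = {"<pad>": 0, "hello": 1}

    def setup(self):
        # Safe: every process runs setup, so each copy gets the attribute.
        self.num_classes = 10


def simulate(world_size):
    """Give each simulated process its own copy, as separate processes would."""
    modules = []
    for rank in range(world_size):
        dm = ToyDataModule()
        if rank == 0:
            dm.prepare_data()  # only one process prepares data
        dm.setup()             # every process sets up
        modules.append(dm)
    return modules


modules = simulate(world_size=2)
print([hasattr(dm, "vocab") for dm in modules])        # [True, False]
print([hasattr(dm, "num_classes") for dm in modules])  # [True, True]
```

Rank 1 never sees `vocab`, which is exactly why state belongs in setup instead.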

setup

There are also data operations you might want to perform on every GPU. Use setup to do things like:

  • count number of classes

  • build vocabulary

  • perform train/val/test splits

  • apply transforms (defined explicitly in your datamodule)

  • etc…
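Counting classes and building a vocabulary are typical setup tasks because every process needs the result. Below is a minimal standard-library sketch, assuming a hypothetical text-classification dataset inlined in the class (a real datamodule would read files written by prepare_data). Index 0 is reserved for padding by choice, not by any Lightning convention.

```python
from collections import Counter
from typing import Optional


class ToyTextClassificationData:
    """Hypothetical datamodule-like class; data is inlined instead of read from disk."""

    def __init__(self, min_freq: int = 1):
        self.min_freq = min_freq
        self.raw = [
            ("the cat sat", "animal"),
            ("the dog ran", "animal"),
            ("stocks fell", "finance"),
        ]

    def setup(self, stage: Optional[str] = None):
        # Count classes from the labels actually present in the data.
        self.classes = sorted({label for _, label in self.raw})
        self.num_classes = len(self.classes)

        # Build a vocabulary: index 0 is padding, then every token seen at
        # least min_freq times, in first-seen order.
        counts = Counter(tok for text, _ in self.raw for tok in text.split())
        self.vocab = {"<pad>": 0}
        for tok in counts:
            if counts[tok] >= self.min_freq:
                self.vocab[tok] = len(self.vocab)


dm = ToyTextClassificationData()
dm.setup("fit")
print(dm.num_classes, len(dm.vocab))  # 2 8
```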

from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: Optional[str] = None):
        # Assumes self.data_dir and self.transform are set in __init__ and that
        # prepare_data has already downloaded MNIST (downloads write to disk).

        # Assign train/val split(s) for use in dataloaders
        if stage in (None, "fit"):
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.dims = self.mnist_train[0][0].shape

        # Assign test split(s) for use in dataloaders
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)

setup() expects a stage: Optional[str] argument. It is used to separate setup logic for trainer.{fit,validate,test}. If setup is called with stage=None, Lightning assumes all stages have been set up.
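To see which attributes each stage produces, here is a data-free sketch that keeps only the two stage checks from the setup() example above. `StageAwareSetup` and `assigned` are hypothetical names. One consequence worth checking in your own code: with only these two checks, a call with stage="validate" (which trainer.validate is expected to use) assigns nothing, so you may want to add "validate" to the first check if you rely on trainer.validate.

```python
from typing import Optional


class StageAwareSetup:
    """Mirrors the stage checks in the setup() example, without loading data."""

    def setup(self, stage: Optional[str] = None):
        if stage in (None, "fit"):
            self.train_split, self.val_split = "train", "val"
        if stage in (None, "test"):
            self.test_split = "test"


def assigned(stage):
    """Return the attribute names that setup(stage) creates on a fresh instance."""
    dm = StageAwareSetup()
    dm.setup(stage)
    return sorted(vars(dm))


print(assigned("fit"))       # ['train_split', 'val_split']
print(assigned("test"))      # ['test_split']
print(assigned(None))        # ['test_split', 'train_split', 'val_split']
print(assigned("validate"))  # []
```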

Note

setup is called from every process. Setting state here is okay.

Note

teardown can be used to clean up the state. It is also called from every process.

Note

Each {setup,teardown,prepare_data} call is only executed once for a given stage. If the stage was None, Lightning assumes {fit,validate,test} have all been called. For example, any duplicate dm.setup('fit') call will be a no-op. To avoid this, you can overwrite dm._has_setup_fit = False.
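The once-per-stage behavior described in this note can be modeled with one flag per stage. The toy class below is written to match the note's description only; it is not Lightning's implementation, and apart from the `_has_setup_fit` name taken from the note, the attribute names and the `calls` log are assumptions for illustration.

```python
from typing import Optional

STAGES = ("fit", "validate", "test")


class OnceSetup:
    """Toy guard modeled on the described behavior; not Lightning's actual code."""

    def __init__(self):
        self.calls = []  # records which setup() calls actually ran
        for stage in STAGES:
            setattr(self, f"_has_setup_{stage}", False)

    def setup(self, stage: Optional[str] = None):
        stages = STAGES if stage is None else (stage,)
        pending = [s for s in stages if not getattr(self, f"_has_setup_{s}")]
        if not pending:
            return  # duplicate call: no-op
        self.calls.append(stage)
        for s in pending:
            setattr(self, f"_has_setup_{s}", True)


dm = OnceSetup()
dm.setup("fit")
dm.setup("fit")            # skipped: "fit" already ran
dm._has_setup_fit = False  # reset the flag, as the note suggests
dm.setup("fit")            # runs again
dm.setup()                 # None marks fit, validate and test as done
dm.setup("test")           # skipped
print(dm.calls)            # ['fit', 'fit', None]
```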

train_dataloader

Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in setup.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)

val_dataloader

Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in setup.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)

test_dataloader

Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in setup.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

predict_dataloader

Use this method to generate a special dataloader for inference. This is the dataloader that the Trainer's predict() method uses.

import pytorch_lightning as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)

transfer_batch_to_device

Override to define how you want to move an arbitrary batch to a device. To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting/sanity_checking and add different logic depending on your requirements.

from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        x = batch["x"]
        x = CustomDataWrapper(x)  # your own type that implements .to(device)
        batch["x"] = x.to(device)
        return batch

Note

This hook only runs for single-GPU training and DDP (not with DataParallel).
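Lightning's default implementation already moves common collections of tensors, so an override is mainly needed for custom types such as the CustomDataWrapper in the example above. The sketch below illustrates the general pattern of walking a nested batch and calling `.to(device)` on every leaf that supports it. `FakeTensor` and `move_to_device` are hypothetical, standard-library-only stand-ins so the logic can run without PyTorch.

```python
class FakeTensor:
    """Hypothetical stand-in for a tensor: only records which device it is on."""

    def __init__(self, value, device="cpu"):
        self.value, self.device = value, device

    def to(self, device):
        return FakeTensor(self.value, device)


def move_to_device(batch, device):
    """Recursively move leaves that have a .to() method, keeping the container shape."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(val, device) for key, val in batch.items()}
    if isinstance(batch, (list, tuple)):  # plain lists/tuples only, not namedtuples
        return type(batch)(move_to_device(val, device) for val in batch)
    return batch  # e.g. strings or ints stay as they are


batch = {"x": FakeTensor([1, 2]), "meta": ["id-1", FakeTensor(3)], "n": 2}
moved = move_to_device(batch, "cuda:0")
print(moved["x"].device, moved["meta"][1].device, moved["meta"][0])  # cuda:0 cuda:0 id-1
```

The function builds new containers rather than mutating the input, so the original batch still refers to CPU objects.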

on_before_batch_transfer

Override to alter or apply augmentations to your batch before it is transferred to the device. To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting/sanity_checking and add different logic depending on your requirements.

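from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        # `transforms` stands for your own CPU-side augmentation callable
        batch["x"] = transforms(batch["x"])
        return batch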

Note

This hook only runs for single-GPU training and DDP (not with DataParallel).
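A common use of the self.trainer.training flag mentioned above is to augment only training batches. In this minimal sketch, the trainer is faked with `types.SimpleNamespace` and the class does not subclass LightningDataModule, so it runs without Lightning. `AugmentingHooks` and the list-reversal "flip" augmentation are hypothetical.

```python
from types import SimpleNamespace


class AugmentingHooks:
    """Hypothetical class: only the hook logic, no Lightning base class."""

    def __init__(self, trainer):
        self.trainer = trainer  # stand-in for the attached Trainer

    def on_before_batch_transfer(self, batch, dataloader_idx):
        # Augment only while training; val/test/predict batches pass through.
        if self.trainer.training:
            batch["x"] = batch["x"][::-1]  # toy "horizontal flip"
        return batch


train_hooks = AugmentingHooks(SimpleNamespace(training=True))
eval_hooks = AugmentingHooks(SimpleNamespace(training=False))
print(train_hooks.on_before_batch_transfer({"x": [1, 2, 3]}, 0))  # {'x': [3, 2, 1]}
print(eval_hooks.on_before_batch_transfer({"x": [1, 2, 3]}, 0))   # {'x': [1, 2, 3]}
```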

on_after_batch_transfer

Override to alter or apply augmentations to your batch after it is transferred to the device. To check the current state of execution of this hook, you can use self.trainer.training/testing/validating/predicting/sanity_checking and add different logic depending on your requirements.

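from pytorch_lightning import LightningDataModule


class MNISTDataModule(LightningDataModule):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        # `gpu_transforms` stands for your own on-device augmentation callable
        batch["x"] = gpu_transforms(batch["x"])
        return batch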

Note

This hook only runs for single-GPU training and DDP (not with DataParallel). It is also called when running on a CPU device, in which case adding augmentations here or in on_before_batch_transfer has the same effect.

Note

To decouple your data from transforms, you can parametrize them via __init__.

import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, train_transforms, val_transforms, test_transforms):
        super().__init__()
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms
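Injecting transforms means the same class can be reused with different augmentations without editing it. The standard-library sketch below shows the idea with a hypothetical `ParamTransformData` class and toy "pixel" values; `scale` and `invert_and_scale` are illustrative transforms, not torchvision ones.

```python
from typing import Callable, Optional


class ParamTransformData:
    """Hypothetical datamodule-like class whose transforms are passed in."""

    def __init__(self, train_transforms: Callable, val_transforms: Callable):
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms

    def setup(self, stage: Optional[str] = None):
        raw = [0.0, 127.5, 255.0]  # toy "pixel" values
        self.train = [self.train_transforms(v) for v in raw]
        self.val = [self.val_transforms(v) for v in raw]


def scale(v: float) -> float:
    return v / 255.0


def invert_and_scale(v: float) -> float:
    # Hypothetical training-only augmentation: invert intensities, then scale.
    return 1.0 - v / 255.0


dm = ParamTransformData(train_transforms=invert_and_scale, val_transforms=scale)
dm.setup()
print(dm.train, dm.val)  # [1.0, 0.5, 0.0] [0.0, 0.5, 1.0]
```

Swapping `invert_and_scale` for another callable changes the training data without touching the splitting or loading code.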

Using a DataModule

The recommended way to use a DataModule is simply:

from pytorch_lightning import Trainer

dm = MNISTDataModule()
model = Model()  # your LightningModule
trainer = Trainer()
trainer.fit(model, dm)
trainer.test(datamodule=dm)

If you need information from the dataset to build your model, run prepare_data() and setup() manually (Lightning ensures the methods run on the correct devices).

dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")

# num_classes, width and vocab are example attributes your setup() could assign
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)

dm.setup(stage="test")
trainer.test(datamodule=dm)

DataModules without Lightning

You can of course use DataModules in plain PyTorch code as well.

# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...
for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")

But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure.


Hyperparameters in DataModules

Like LightningModules, DataModules support hyperparameters with the same API.

import pytorch_lightning as pl


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, *args, **kwargs):
        super().__init__()
        self.save_hyperparameters()

Refer to save_hyperparameters in the LightningModule documentation for more details.
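The idea behind save_hyperparameters is to record the arguments passed to __init__ so they can be logged or used to re-create the object. The toy sketch below imitates that idea with the standard `inspect` module by reading the caller's frame. It is not Lightning's implementation: `ToyHparams` and `save_hparams` are hypothetical names, and unlike the real method this toy ignores `*args`/`**kwargs` and only captures named parameters.

```python
import inspect


class ToyHparams:
    """Toy imitation of the idea behind save_hyperparameters(); not Lightning's code."""

    def save_hparams(self):
        # The caller's frame is the subclass __init__; copy its named arguments.
        frame = inspect.currentframe().f_back
        try:
            info = inspect.getargvalues(frame)
            self.hparams = {name: info.locals[name] for name in info.args if name != "self"}
        finally:
            del frame  # avoid keeping a reference cycle to the frame


class ToyDataModule(ToyHparams):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        self.save_hparams()
        self.data_dir = data_dir
        self.batch_size = batch_size


dm = ToyDataModule(batch_size=64)
print(dm.hparams)  # {'data_dir': './', 'batch_size': 64}
```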