LightningDataModule
A datamodule is a shareable, reusable class that encapsulates the five steps involved in data processing in PyTorch:
Download / tokenize / process.
Clean and (maybe) save to disk.
Load inside Dataset.
Apply transforms (rotate, tokenize, etc…).
Wrap inside a DataLoader.
This class can then be shared and used anywhere:
from pytorch_lightning import Trainer
from pl_bolts.datamodules import CIFAR10DataModule, ImagenetDataModule
model = LitClassifier()
trainer = Trainer()
imagenet = ImagenetDataModule()
trainer.fit(model, imagenet)
cifar10 = CIFAR10DataModule()
trainer.fit(model, cifar10)
Why do I need a DataModule?
In normal PyTorch code, data cleaning and preparation are usually scattered across many files. This makes it hard to share and reuse the exact splits and transforms across projects.
Datamodules are for you if you have ever asked these questions:
what splits did you use?
what transforms did you use?
what normalization did you use?
how did you prepare/tokenize the data?
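One way to make the first of these questions answerable is to derive the split from a fixed seed, so that anyone rerunning the code gets the same partition. The sketch below uses only the standard library; the helper name split_indices and the seed value are illustrative assumptions, not part of Lightning. In PyTorch itself, random_split accepts a generator argument that serves the same purpose.

```python
import random


def split_indices(n_items, n_val, seed=42):
    """Return (train, val) index lists that are identical for a given seed."""
    indices = list(range(n_items))
    # A dedicated Random instance keeps the split independent of global RNG state.
    random.Random(seed).shuffle(indices)
    return indices[n_val:], indices[:n_val]


train_idx, val_idx = split_indices(n_items=100, n_val=10)
again_train, again_val = split_indices(n_items=100, n_val=10)
print(len(train_idx), len(val_idx))
print(train_idx == again_train and val_idx == again_val)
```

Recording the seed next to the split sizes is usually enough for a colleague to reproduce the same train/val partition.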
What is a DataModule
A DataModule is simply a collection of train_dataloader(s), val_dataloader(s), and test_dataloader(s), along with the matching transforms and the data processing/download steps they require.
Here’s a simple PyTorch example:
# regular PyTorch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

test_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])
train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
The equivalent DataModule just organizes the same exact code, but makes it reusable across projects.
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: Optional[str] = None):
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def teardown(self, stage: Optional[str] = None):
        # Used to clean-up when the run is finished
        ...
Now, as the complexity of your processing grows (transforms, multi-GPU training), you can let Lightning handle those details for you while keeping the dataset reusable, so you can share it with colleagues or use it in different projects.
mnist = MNISTDataModule(my_path)
model = LitClassifier()
trainer = Trainer()
trainer.fit(model, mnist)
Here’s a more realistic, complex DataModule that shows how much more reusable the datamodule is.
from typing import Optional

import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # self.dims is returned when you call dm.size()
        # Setting default dims here because we know them.
        # Could optionally be assigned dynamically in dm.setup()
        self.dims = (1, 28, 28)

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

            # Optionally...
            # self.dims = tuple(self.mnist_train[0][0].shape)

        # Assign test dataset for use in dataloader(s)
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

            # Optionally...
            # self.dims = tuple(self.mnist_test[0][0].shape)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)
LightningDataModule API
To define a DataModule, implement the following five methods:
prepare_data (how to download, tokenize, etc…)
setup (how to split, etc…)
train_dataloader
val_dataloader(s)
test_dataloader(s)
and optionally one or multiple predict_dataloader(s).
prepare_data
Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.
download
tokenize
etc…
import os


class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
Warning
prepare_data is called from a single process (e.g. GPU 0). Do not use it to assign state (self.x = y).
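The reason behind this warning can be illustrated without any GPUs. The sketch below is a plain-Python simulation, not Lightning code: ToyDataModule and simulate are hypothetical names, and each list entry stands in for the datamodule copy that lives in one process.

```python
class ToyDataModule:
    """Plain-Python stand-in for a datamodule (not the Lightning class)."""

    def prepare_data(self):
        # Anti-pattern: state assigned here exists only in the process that ran it.
        self.vocab_size = 100

    def setup(self, stage=None):
        # Runs in every process, so state assigned here is available everywhere.
        self.num_classes = 10


def simulate(world_size):
    """Mimic the call pattern: prepare_data on rank 0 only, setup on every rank."""
    modules = [ToyDataModule() for _ in range(world_size)]
    modules[0].prepare_data()
    for dm in modules:
        dm.setup("fit")
    return modules


ranks = simulate(world_size=2)
print(hasattr(ranks[0], "vocab_size"))
print(hasattr(ranks[1], "vocab_size"))
print(all(hasattr(dm, "num_classes") for dm in ranks))
```

In this simulation the second copy never sees vocab_size, which mirrors what would happen to state assigned in prepare_data on any process other than the one that ran it.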
setup
There are also data operations you might want to perform on every GPU. Use setup to do things like:
count number of classes
build vocabulary
perform train/val/test splits
apply transforms (defined explicitly in your datamodule or assigned in init)
etc…
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def setup(self, stage: Optional[str] = None):

        # Assign Train/val split(s) for use in Dataloaders
        if stage in (None, "fit"):
            mnist_full = MNIST(self.data_dir, train=True, download=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])
            self.dims = self.mnist_train[0][0].shape

        # Assign Test split(s) for use in Dataloaders
        if stage in (None, "test"):
            self.mnist_test = MNIST(self.data_dir, train=False, download=True, transform=self.transform)
            self.dims = getattr(self, "dims", self.mnist_test[0][0].shape)
setup() expects a stage: Optional[str] argument. It is used to separate setup logic for trainer.{fit,validate,test}. If setup is called with stage = None, we assume all stages have been set up.
Note
setup is called from every process. Setting state here is okay.
Note
teardown can be used to clean up the state. It is also called from every process.
Note
Each of {setup,teardown,prepare_data} is called only once for a specific stage. If the stage was None, then we assume {fit,validate,test} have been called. For example, this means that any duplicate dm.setup('fit') call will be a no-op. To avoid this, you can overwrite dm._has_setup_fit = False
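The once-per-stage behaviour can be pictured with a toy guard. This is only an illustration in plain Python and is not Lightning's implementation; the class OncePerStageDataModule, its reset method, and the setup_calls list are hypothetical names introduced for the sketch.

```python
class OncePerStageDataModule:
    """Toy stand-in that mimics a once-per-stage guard around setup()."""

    STAGES = ("fit", "validate", "test")

    def __init__(self):
        self._done = set()
        self.setup_calls = []  # records the calls that actually did work

    def setup(self, stage=None):
        # stage=None is treated as "every stage".
        requested = set(self.STAGES) if stage is None else {stage}
        if requested <= self._done:
            return  # duplicate call: nothing left to do
        self._done |= requested
        self.setup_calls.append(stage)

    def reset(self, stage):
        # Comparable in spirit to clearing a flag such as _has_setup_fit.
        self._done.discard(stage)


dm = OncePerStageDataModule()
dm.setup("fit")
dm.setup("fit")  # ignored, "fit" is already set up
dm.setup("test")
dm.reset("fit")
dm.setup("fit")  # runs again after the reset
print(dm.setup_calls)
```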
train_dataloader
Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in setup.
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
val_dataloader
Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in setup.
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
test_dataloader
Use this method to generate the test dataloader. Usually you just wrap the dataset you defined in setup.
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
predict_dataloader
Returns a special dataloader for inference. This is the dataloader that the Trainer predict() method uses.
import pytorch_lightning as pl


class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
transfer_batch_to_device
Override to define how you want to move an arbitrary batch to a device.
To find out which stage this hook is running in, check self.trainer.training/testing/validating/predicting and branch your logic accordingly.
class MNISTDataModule(LightningDataModule):
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        x = batch["x"]
        x = CustomDataWrapper(x)
        batch["x"] = x.to(device)
        return batch
Note
This hook only runs on single GPU training and DDP (no data-parallel).
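To make the idea of moving an arbitrary batch concrete without requiring a GPU, here is a dependency-free sketch. FakeTensor and move_to_device are hypothetical stand-ins; in real code the objects would be tensors or your own wrapper exposing .to(device), and the device would be a torch device rather than a string.

```python
class FakeTensor:
    """Hypothetical stand-in for any object that exposes .to(device)."""

    def __init__(self, values, device="cpu"):
        self.values = values
        self.device = device

    def to(self, device):
        # Return a copy tagged with the target device; the original is untouched.
        return FakeTensor(self.values, device)


def move_to_device(batch, device):
    """Recursively move every .to()-capable object inside dicts, lists and tuples."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {key: move_to_device(value, device) for key, value in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(item, device) for item in batch)
    return batch  # plain Python values (ints, strings, ...) are returned unchanged


batch = {"x": FakeTensor([1, 2, 3]), "meta": ("id-7", FakeTensor([0]))}
moved = move_to_device(batch, "cuda:0")
print(moved["x"].device, moved["meta"][1].device, moved["meta"][0])
```

A hook override could delegate to a helper of this kind for the parts of the batch that the default transfer logic does not know how to handle.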
on_before_batch_transfer
Override to alter or apply augmentations to your batch before it is transferred to the device.
To find out which stage this hook is running in, check self.trainer.training/testing/validating/predicting and branch your logic accordingly.
class MNISTDataModule(LightningDataModule):
    def on_before_batch_transfer(self, batch, dataloader_idx):
        batch["x"] = transforms(batch["x"])
        return batch
Note
This hook only runs on single GPU training and DDP (no data-parallel).
on_after_batch_transfer
Override to alter or apply augmentations to your batch after it is transferred to the device.
To find out which stage this hook is running in, check self.trainer.training/testing/validating/predicting and branch your logic accordingly.
class MNISTDataModule(LightningDataModule):
    def on_after_batch_transfer(self, batch, dataloader_idx):
        batch["x"] = gpu_transforms(batch["x"])
        return batch
Note
This hook only runs on single GPU training and DDP (no data-parallel). This hook will also be called when using CPU device, so adding augmentations here or in on_before_batch_transfer means the same thing.
Note
To decouple your data from transforms you can parametrize them via __init__.
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, train_transforms, val_transforms, test_transforms):
        super().__init__()
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms
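As a rough illustration of what this decoupling buys you, the sketch below passes different callables for training and validation into the same data class. It is plain Python rather than Lightning: ToyDataModule, scale, scale_and_invert, and the pretend pixel values are all hypothetical.

```python
class ToyDataModule:
    """Plain-Python stand-in; a real datamodule would subclass LightningDataModule."""

    def __init__(self, train_transforms, val_transforms):
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.raw = [0, 128, 255]  # pretend pixel values

    def train_data(self):
        return [self.train_transforms(x) for x in self.raw]

    def val_data(self):
        return [self.val_transforms(x) for x in self.raw]


def scale(x):
    """Map a pixel value from [0, 255] to [0, 1]."""
    return x / 255


def scale_and_invert(x):
    """Scale to [0, 1] and invert, standing in for a training-only augmentation."""
    return 1 - x / 255


dm = ToyDataModule(train_transforms=scale_and_invert, val_transforms=scale)
print(dm.train_data())
print(dm.val_data())
```

Because the transforms arrive as arguments, the same data class can be reused with a different preprocessing recipe without editing its code.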
Using a DataModule
The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer = Trainer()
trainer.fit(model, dm)
trainer.test(datamodule=dm)
If you need information from the dataset to build your model, then run prepare_data() and setup() manually (Lightning ensures the methods run on the correct devices).
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)
dm.setup(stage="test")
trainer.test(datamodule=dm)
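Attributes such as num_classes or vocab are not provided automatically; they have to be computed by your own setup(). The following is a minimal sketch of that pattern in plain Python, with hypothetical ToyTextDataModule and ToyModel classes standing in for a real datamodule and model.

```python
class ToyTextDataModule:
    """Plain-Python stand-in whose setup() derives facts the model needs."""

    def __init__(self, samples):
        self.samples = samples  # list of (text, label) pairs

    def prepare_data(self):
        pass  # nothing to download in this toy example

    def setup(self, stage=None):
        # Build the vocabulary and count the classes from the raw samples.
        words = sorted({word for text, _ in self.samples for word in text.split()})
        self.vocab = {word: index for index, word in enumerate(words)}
        self.num_classes = len({label for _, label in self.samples})


class ToyModel:
    """Hypothetical model whose layer sizes depend on the data."""

    def __init__(self, num_classes, vocab_size):
        self.output_size = num_classes
        self.embedding_rows = vocab_size


dm = ToyTextDataModule([("good movie", 1), ("bad movie", 0), ("good plot", 1)])
dm.prepare_data()
dm.setup(stage="fit")
model = ToyModel(num_classes=dm.num_classes, vocab_size=len(dm.vocab))
print(model.output_size, model.embedding_rows)
```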
DataModules without Lightning
You can of course use DataModules in plain PyTorch code as well.
# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")
Overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure.