LightningDataModule
A datamodule is a shareable, reusable class that encapsulates the five steps involved in data processing in PyTorch:
Download / tokenize / process.
Clean and (maybe) save to disk.
Load inside a Dataset.
Apply transforms (rotate, tokenize, etc…).
Wrap inside a DataLoader.
This class can then be shared and used anywhere:
model = LitClassifier()
trainer = Trainer()
imagenet = ImagenetDataModule()
trainer.fit(model, datamodule=imagenet)
cifar10 = CIFAR10DataModule()
trainer.fit(model, datamodule=cifar10)
Why do I need a DataModule?
In normal PyTorch code, the data cleaning/preparation is usually scattered across many files. This makes it hard to share and reuse the exact splits and transforms across projects.
Datamodules are for you if you have ever asked questions like:
what splits did you use?
what transforms did you use?
what normalization did you use?
how did you prepare/tokenize the data?
What is a DataModule?
The LightningDataModule is a convenient way to manage data in PyTorch Lightning. It encapsulates training, validation, testing, and prediction dataloaders, as well as any necessary steps for data processing, downloads, and transformations. By using a LightningDataModule, you can easily develop dataset-agnostic models, hot-swap different datasets, and share data splits and transformations across projects.
Here’s a simple PyTorch example:
# regular PyTorch
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST

test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
The equivalent DataModule just organizes the same exact code, but makes it reusable across projects.
import torch
import lightning.pytorch as pl
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=self.batch_size)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=self.batch_size)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=self.batch_size)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=self.batch_size)

    def teardown(self, stage: str):
        # Used to clean-up when the run is finished
        ...
As the complexity of your processing grows (transforms, multi-GPU training), you can let Lightning handle those details for you while keeping this dataset reusable, so you can share it with colleagues or use it in different projects.
mnist = MNISTDataModule(my_path)
model = LitClassifier()
trainer = Trainer()
trainer.fit(model, mnist)
Here’s a more realistic, complex DataModule that shows how much more reusable the datamodule is.
import torch
import lightning.pytorch as pl
from torch.utils.data import random_split, DataLoader

# Note - you must have torchvision installed for this example
from torchvision.datasets import MNIST
from torchvision import transforms


class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
LightningDataModule API
To define a DataModule, the following methods are used to create the train/val/test/predict dataloaders:
prepare_data (how to download, tokenize, etc…)
setup (how to split, define dataset, etc…)
prepare_data
Downloading and saving data with multiple processes (distributed settings) will result in corrupted data. Lightning ensures that prepare_data() is called only within a single process on CPU, so you can safely add your downloading logic within it. In multi-node training, which processes run this hook depends on prepare_data_per_node. setup() is called after prepare_data, and there is a barrier in between which ensures that all processes proceed to setup only once the data is prepared and available for use.
Typical uses of prepare_data():
download, i.e. download data to disk only once, from a single process
tokenize. Since it’s a one-time process, it is not recommended to do it on all processes
etc…
import os
import lightning.pytorch as pl
from torchvision import transforms
from torchvision.datasets import MNIST

class MNISTDataModule(pl.LightningDataModule):
    def prepare_data(self):
        # download
        MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
        MNIST(os.getcwd(), train=False, download=True, transform=transforms.ToTensor())
Warning
prepare_data is called from the main process. It is not recommended to assign state here (e.g. self.x = y), since it is called on a single process and any state you assign here won’t be available to other processes.
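The prepare_data/setup split and the warning above can be illustrated with a small framework-free simulation, a minimal sketch under stated assumptions: threads stand in for distributed processes, a shared dict stands in for the disk, a threading.Barrier stands in for Lightning's barrier, and ToyDataModule, run_rank, and simulate are hypothetical names, not Lightning APIs. It mimics the behavior described above and is not how Lightning implements it.

```python
import threading


class ToyDataModule:
    """Hypothetical stand-in for a datamodule; not the Lightning API."""

    def __init__(self, disk):
        self.disk = disk  # shared storage, modeled as a dict

    def prepare_data(self):
        # Single process: write the "downloaded" data to shared storage.
        self.disk["mnist"] = list(range(10))
        # Anti-pattern: this attribute only exists on this process's copy.
        self.downloaded = True

    def setup(self, stage):
        # Every process: read from shared storage and set per-process state.
        self.data = self.disk["mnist"]


def run_rank(rank, barrier, disk, results):
    dm = ToyDataModule(disk)  # every process builds its own copy
    if rank == 0:
        dm.prepare_data()
    barrier.wait()  # no rank calls setup() until the data exists
    dm.setup("fit")
    results[rank] = {"num_items": len(dm.data), "downloaded": hasattr(dm, "downloaded")}


def simulate(world_size=4):
    barrier = threading.Barrier(world_size)
    disk, results = {}, {}
    threads = [
        threading.Thread(target=run_rank, args=(rank, barrier, disk, results))
        for rank in range(world_size)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results


results = simulate()
```

Every rank finds the data in setup(), but only rank 0's copy carries the downloaded attribute, which is why state belongs in setup() rather than in prepare_data().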
setup
There are also data operations you might want to perform on every GPU. Use setup() to do things like:
count number of classes
build vocabulary
perform train/val/test splits
create datasets
apply transforms (defined explicitly in your datamodule)
etc…
import torch
import lightning.pytorch as pl
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(pl.LightningDataModule):
    # assumes self.data_dir and self.transform are set in __init__,
    # and that prepare_data() has already downloaded MNIST
    def setup(self, stage: str):
        # Assign Train/val split(s) for use in Dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(
                mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
            )

        # Assign Test split(s) for use in Dataloaders
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)
For example, if you are working on an NLP task where you need to tokenize the text before using it, you can do something like the following:
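from lightning.pytorch import LightningDataModule


class LitDataModule(LightningDataModule):
    # load_dataset / load_dataset_from_disk stand for your own loading helpers
    def prepare_data(self):
        dataset = load_dataset(...)
        train_dataset = ...
        val_dataset = ...
        # tokenize
        # save it to disk

    def setup(self, stage):
        # load it back here
        dataset = load_dataset_from_disk(...)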
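The tokenize-once, load-everywhere pattern above can be sketched without any framework. This sketch assumes a whitespace tokenizer (str.split) in place of a real one and a JSON file as the on-disk format; ToyTextDataModule and the sample sentences are hypothetical.

```python
import json
import os
import tempfile


class ToyTextDataModule:
    """Hypothetical, framework-free sketch of the tokenize-once pattern."""

    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.cache_path = os.path.join(data_dir, "tokenized.json")

    def prepare_data(self):
        # One-time, single-process work: tokenize and save to disk.
        raw = {"train": ["hello world", "lightning data"], "val": ["hello lightning"]}
        tokenized = {split: [text.split() for text in texts] for split, texts in raw.items()}
        with open(self.cache_path, "w") as f:
            json.dump(tokenized, f)

    def setup(self, stage):
        # Every process: load the tokenized data back from disk.
        with open(self.cache_path) as f:
            tokenized = json.load(f)
        self.train_tokens = tokenized["train"]
        self.val_tokens = tokenized["val"]


with tempfile.TemporaryDirectory() as tmp:
    dm = ToyTextDataModule(tmp)
    dm.prepare_data()
    dm.setup("fit")
    print(dm.train_tokens)  # [['hello', 'world'], ['lightning', 'data']]
```

The expensive tokenization runs once, while setup() only reads the cached result, so it stays cheap enough to run on every process.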
This method expects a stage argument. It is used to separate setup logic for trainer.{fit,validate,test,predict}.
Note
setup is called from every process across all the nodes. Setting state here is recommended.
Note
teardown can be used to clean up the state. It is also called from every process across all the nodes.
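Because setup() runs on every process, each process must compute the same train/val split; this is why the examples seed random_split with torch.Generator().manual_seed(42). The sketch below shows the idea with the standard library only. seeded_split is a hypothetical helper that shuffles with random.Random(seed); it is not PyTorch's algorithm, so its indices will not match random_split's.

```python
import random


def seeded_split(items, lengths, seed=42):
    """Shuffle indices with a fixed seed, then cut them into consecutive chunks."""
    if sum(lengths) != len(items):
        raise ValueError("lengths must sum to len(items)")
    indices = list(range(len(items)))
    random.Random(seed).shuffle(indices)  # private RNG: same seed, same order
    splits, start = [], 0
    for n in lengths:
        splits.append([items[i] for i in indices[start:start + n]])
        start += n
    return splits


# Simulate setup("fit") running independently on two processes.
data = list(range(100))
train_a, val_a = seeded_split(data, [90, 10])
train_b, val_b = seeded_split(data, [90, 10])
print(val_a == val_b)  # True
```

Without a fixed seed, each process could draw a different split, and samples used for training on one process could end up in validation on another.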
train_dataloader
Use the train_dataloader() method to generate the training dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() method uses.
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=64)
val_dataloader
Use the val_dataloader() method to generate the validation dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer fit() and validate() methods use.
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=64)
test_dataloader
Use the test_dataloader() method to generate the test dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer test() method uses.
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=64)
predict_dataloader
Use the predict_dataloader() method to generate the prediction dataloader(s). Usually you just wrap the dataset you defined in setup. This is the dataloader that the Trainer predict() method uses.
import lightning.pytorch as pl
from torch.utils.data import DataLoader


class MNISTDataModule(pl.LightningDataModule):
    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=64)
transfer_batch_to_device
LightningDataModule.transfer_batch_to_device(batch, device, dataloader_idx)
Override this hook if your DataLoader returns tensors wrapped in a custom data structure. The data types listed below (and any arbitrary nesting of them) are supported out of the box:
torch.Tensor or anything that implements .to(…)
list
dict
tuple
For anything else, you need to define how the data is moved to the target device (CPU, GPU, TPU, …).
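Note
This hook should only transfer the data and not modify it, nor should it move the data to any device other than the one passed in as an argument (unless you know what you are doing). To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Parameters:
batch (Any) – A batch of data that needs to be transferred to a new device.
device (torch.device) – The target device as defined in PyTorch.
dataloader_idx (int) – The index of the dataloader to which the batch belongs.
Return type:
Any
Returns:
A reference to the data on the new device.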
Example:
def transfer_batch_to_device(self, batch, device, dataloader_idx):
    if isinstance(batch, CustomBatch):
        # move all tensors in your custom data structure to the device
        batch.samples = batch.samples.to(device)
        batch.targets = batch.targets.to(device)
    elif dataloader_idx == 0:
        # skip device transfer for the first dataloader or anything you wish
        pass
    else:
        batch = super().transfer_batch_to_device(batch, device, dataloader_idx)
    return batch
Raises:
MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').
See also
move_data_to_device()
apply_to_collection()
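To see why a custom container needs this hook, here is a simplified, framework-free sketch of recursive device transfer. FakeTensor (anything with a .to(device) method), CustomBatch, move_to_device, and custom_transfer are hypothetical stand-ins, and devices are plain strings; Lightning's real utility is move_data_to_device(), whose implementation differs.

```python
from dataclasses import dataclass


@dataclass
class FakeTensor:
    """Stand-in for torch.Tensor: anything with a .to(device) method."""
    values: list
    device: str = "cpu"

    def to(self, device):
        return FakeTensor(self.values, device)


@dataclass
class CustomBatch:
    """Hypothetical custom container that has no .to() method."""
    samples: FakeTensor
    targets: FakeTensor


def move_to_device(batch, device):
    """Generic transfer: recurse into dicts/lists/tuples, call .to() on leaves."""
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    if isinstance(batch, (list, tuple)):
        return type(batch)(move_to_device(v, device) for v in batch)
    return batch  # unknown types are returned untouched


def custom_transfer(batch, device):
    """What an override adds: explicit handling for the custom container."""
    if isinstance(batch, CustomBatch):
        return CustomBatch(batch.samples.to(device), batch.targets.to(device))
    return move_to_device(batch, device)


nested = {"x": [FakeTensor([1]), FakeTensor([2])], "y": (FakeTensor([3]),)}
moved = move_to_device(nested, "cuda:0")  # nesting is handled generically

custom = CustomBatch(FakeTensor([1]), FakeTensor([0]))
unchanged = move_to_device(custom, "cuda:0")  # left on "cpu": type not recognized
handled = custom_transfer(custom, "cuda:0")   # moved by the explicit override
```

The generic path only knows about collections and objects with .to(); anything else passes through silently, which is the gap an override of transfer_batch_to_device fills.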
on_before_batch_transfer
LightningDataModule.on_before_batch_transfer(batch, dataloader_idx)
Override to alter or apply batch augmentations to your batch before it is transferred to the device.
Note
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Parameters:
batch (Any) – A batch of data that needs to be altered or augmented.
dataloader_idx (int) – The index of the dataloader to which the batch belongs.
Return type:
Any
Returns:
A batch of data
Example:
def on_before_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = transforms(batch['x'])
    return batch
See also
on_after_batch_transfer()
transfer_batch_to_device()
on_after_batch_transfer
LightningDataModule.on_after_batch_transfer(batch, dataloader_idx)
Override to alter or apply batch augmentations to your batch after it is transferred to the device.
Note
To check the current state of execution of this hook you can use self.trainer.training/testing/validating/predicting so that you can add different logic as per your requirement.
Parameters:
batch (Any) – A batch of data that needs to be altered or augmented.
dataloader_idx (int) – The index of the dataloader to which the batch belongs.
Return type:
Any
Returns:
A batch of data
Example:
def on_after_batch_transfer(self, batch, dataloader_idx):
    batch['x'] = gpu_transforms(batch['x'])
    return batch
Raises:
MisconfigurationException – If using IPUs, Trainer(accelerator='ipu').
See also
on_before_batch_transfer()
transfer_batch_to_device()
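For each batch, Lightning calls on_before_batch_transfer, then transfer_batch_to_device, then on_after_batch_transfer. The sketch below mimics that order with plain functions that reuse the hook names; the dict batch, the string device, and run_transfer_hooks are hypothetical, and the "augmentations" are toy arithmetic.

```python
def on_before_batch_transfer(batch, log):
    log.append("on_before_batch_transfer")
    # e.g. a cheap CPU-side augmentation
    return {**batch, "x": [v * 2 for v in batch["x"]]}


def transfer_batch_to_device(batch, device, log):
    log.append("transfer_batch_to_device")
    # only records the target device; a real hook would move tensors
    return {**batch, "device": device}


def on_after_batch_transfer(batch, log):
    log.append("on_after_batch_transfer")
    # e.g. an augmentation meant to run on the accelerator
    return {**batch, "x": [v + 1 for v in batch["x"]]}


def run_transfer_hooks(batch, device):
    """Apply the three hooks in the order Lightning calls them."""
    log = []
    batch = on_before_batch_transfer(batch, log)
    batch = transfer_batch_to_device(batch, device, log)
    batch = on_after_batch_transfer(batch, log)
    return batch, log


batch, log = run_transfer_hooks({"x": [1, 2, 3]}, "cuda:0")
print(batch)  # {'x': [3, 5, 7], 'device': 'cuda:0'}
```

A rule of thumb that follows from this order: put CPU-friendly work in on_before_batch_transfer and work that benefits from the accelerator in on_after_batch_transfer.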
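load_state_dict
Called when loading a checkpoint; implement it to reload the datamodule state from the given state_dict.
state_dict
Called when saving a checkpoint; implement it to generate and return the datamodule state to save.
teardown
Called at the end of fit (train + validate), validate, test, or predict. Use it to clean up state; it is called from every process across all the nodes.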
prepare_data_per_node
If set to True, prepare_data() is called on LOCAL_RANK=0 of every node. If set to False, it is only called from NODE_RANK=0, LOCAL_RANK=0.
from lightning.pytorch import LightningDataModule


class LitDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.prepare_data_per_node = True
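The rule above can be written as a small predicate. should_call_prepare_data and callers are hypothetical helpers (Lightning evaluates this internally); the example assumes 2 nodes with 4 processes each.

```python
def should_call_prepare_data(node_rank: int, local_rank: int, prepare_data_per_node: bool) -> bool:
    """Whether a process with these ranks runs prepare_data(), per the rule above."""
    if prepare_data_per_node:
        # one process per node: LOCAL_RANK=0 on every node
        return local_rank == 0
    # one process in the whole job: NODE_RANK=0, LOCAL_RANK=0
    return node_rank == 0 and local_rank == 0


def callers(num_nodes: int, procs_per_node: int, per_node: bool) -> list:
    """List the (node_rank, local_rank) pairs that would run prepare_data()."""
    return [
        (node, local)
        for node in range(num_nodes)
        for local in range(procs_per_node)
        if should_call_prepare_data(node, local, per_node)
    ]


print(callers(2, 4, True))   # [(0, 0), (1, 0)]
print(callers(2, 4, False))  # [(0, 0)]
```

As a rough guide, True suits nodes that each have their own local disk, while False avoids redundant downloads when all nodes share one filesystem.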
Using a DataModule
The recommended way to use a DataModule is simply:
dm = MNISTDataModule()
model = Model()
trainer.fit(model, datamodule=dm)
trainer.test(datamodule=dm)
trainer.validate(datamodule=dm)
trainer.predict(datamodule=dm)
If you need information from the dataset to build your model, then run prepare_data and setup manually (Lightning ensures the method runs on the correct devices).
dm = MNISTDataModule()
dm.prepare_data()
dm.setup(stage="fit")
model = Model(num_classes=dm.num_classes, width=dm.width, vocab=dm.vocab)
trainer.fit(model, dm)
dm.setup(stage="test")
trainer.test(datamodule=dm)
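A framework-free sketch of why setup() must run before the model is built in this pattern: metadata such as num_classes only exists once setup() has looked at the data. ToyDataModule, ToyModel, and the label list are hypothetical.

```python
class ToyDataModule:
    """Hypothetical datamodule that derives metadata from the data in setup()."""

    def prepare_data(self):
        pass  # nothing to download in this toy example

    def setup(self, stage):
        self.labels = ["cat", "dog", "cat", "bird"]  # stand-in for real targets
        self.num_classes = len(set(self.labels))


class ToyModel:
    """Hypothetical model whose output size depends on the data."""

    def __init__(self, num_classes):
        self.num_classes = num_classes


dm = ToyDataModule()
dm.prepare_data()
has_metadata_before_setup = hasattr(dm, "num_classes")  # False: setup() has not run
dm.setup(stage="fit")
model = ToyModel(num_classes=dm.num_classes)
print(model.num_classes)  # 3
```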
You can access the datamodule currently used by a trainer via trainer.datamodule, and the dataloaders currently in use via the trainer properties train_dataloader, val_dataloaders, test_dataloaders, and predict_dataloaders.
DataModules without Lightning
You can of course use DataModules in plain PyTorch code as well.
# download, etc...
dm = MNISTDataModule()
dm.prepare_data()

# splits/transforms
dm.setup(stage="fit")

# use data
for batch in dm.train_dataloader():
    ...

for batch in dm.val_dataloader():
    ...

dm.teardown(stage="fit")

# lazy load test data
dm.setup(stage="test")
for batch in dm.test_dataloader():
    ...

dm.teardown(stage="test")
But overall, DataModules encourage reproducibility by allowing all details of a dataset to be specified in a unified structure.
Hyperparameters in DataModules
Like LightningModules, DataModules support hyperparameters with the same API.
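import lightning.pytorch as pl
from torch.utils.data import DataLoader


class CustomDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./", batch_size: int = 32):
        super().__init__()
        # stores data_dir and batch_size in self.hparams
        self.save_hyperparameters()

    def train_dataloader(self):
        # access the saved hyperparameters
        # (self.mnist_train is assumed to be created in setup())
        return DataLoader(self.mnist_train, batch_size=self.hparams.batch_size)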
Refer to save_hyperparameters in the LightningModule documentation for more details.
Save DataModule state
When a checkpoint is created, Lightning asks every DataModule for its state. If your DataModule defines the state_dict and load_state_dict methods, the checkpoint will automatically track and restore your DataModule.
import lightning.pytorch as pl


class LitDataModule(pl.LightningDataModule):
    def state_dict(self):
        # track whatever you want here
        state = {"current_train_batch_index": self.current_train_batch_index}
        return state

    def load_state_dict(self, state_dict):
        # restore the state based on what you tracked in (def state_dict)
        self.current_train_batch_index = state_dict["current_train_batch_index"]
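A minimal round trip of the two methods, with a JSON string standing in for the checkpoint file. ToyDataModule and the "datamodule" key are hypothetical; Lightning chooses its own checkpoint layout and stores this state alongside the model weights.

```python
import json


class ToyDataModule:
    """Hypothetical stand-in implementing the same two methods as above."""

    def __init__(self):
        self.current_train_batch_index = 0

    def state_dict(self):
        return {"current_train_batch_index": self.current_train_batch_index}

    def load_state_dict(self, state_dict):
        self.current_train_batch_index = state_dict["current_train_batch_index"]


# "Save": serialize the datamodule state (hypothetical checkpoint layout).
dm = ToyDataModule()
dm.current_train_batch_index = 17
saved = json.dumps({"datamodule": dm.state_dict()})

# "Resume": a fresh instance restores its position from the saved state.
restored = ToyDataModule()
restored.load_state_dict(json.loads(saved)["datamodule"])
print(restored.current_train_batch_index)  # 17
```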