LightningDataModule

class lightning.pytorch.core.LightningDataModule

Bases: DataHooks, HyperparametersMixin

A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is that the data splits, preparation, and transforms stay consistent across models.

Example:

import lightning as L
import torch
import torch.utils.data as data
from lightning.pytorch.demos.boring_classes import RandomDataset

class MyDataModule(L.LightningDataModule):
    def prepare_data(self):
        # download, IO, etc. Useful with shared filesystems
        # called on one process per node by default, not on every process (see prepare_data_per_node)
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        dataset = RandomDataset(1, 100)
        self.train, self.val, self.test = data.random_split(
            dataset, [80, 10, 10], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return data.DataLoader(self.train)

    def val_dataloader(self):
        return data.DataLoader(self.val)

    def test_dataloader(self):
        return data.DataLoader(self.test)

    def teardown(self, stage):
        # clean up state after the trainer stops, delete files...
        # called on every process in DDP
        ...

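
The seeded generator passed to random_split() in setup() is what keeps the splits identical across runs and across models. The sketch below illustrates the same idea with only the standard library; the helper seeded_split is hypothetical, and it does not reproduce the exact index order that torch's random_split produces for the same seed.

```python
import random
from typing import List, Sequence, Tuple


def seeded_split(n_items: int, lengths: Sequence[int], seed: int) -> Tuple[List[int], ...]:
    """Hypothetical helper: shuffle indices 0..n_items-1 with a fixed seed,
    then cut them into consecutive chunks of the requested lengths."""
    if sum(lengths) != n_items:
        raise ValueError("lengths must sum to n_items")
    indices = list(range(n_items))
    # A dedicated Random instance leaves the global random state untouched
    random.Random(seed).shuffle(indices)
    splits = []
    offset = 0
    for length in lengths:
        splits.append(indices[offset : offset + length])
        offset += length
    return tuple(splits)


# Same 80/10/10 proportions as the RandomDataset(1, 100) example above
train, val, test = seeded_split(100, [80, 10, 10], seed=42)
# Calling again with the same seed yields exactly the same split
assert seeded_split(100, [80, 10, 10], seed=42) == (train, val, test)
```

Because the split depends only on the seed and the dataset length, every model trained with this datamodule sees the same training, validation, and test samples.
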
prepare_data_per_node

If True, the process with LOCAL_RANK=0 on each node calls prepare_data(). Otherwise, only the process with NODE_RANK=0 and LOCAL_RANK=0 calls it. Default value is True.
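
The rank selection described above can be modeled as a small predicate. This is a sketch of the documented behaviour only; the function calls_prepare_data is hypothetical and is not part of Lightning. In a real datamodule you would instead set self.prepare_data_per_node = False in __init__.

```python
def calls_prepare_data(node_rank: int, local_rank: int, prepare_data_per_node: bool) -> bool:
    """Hypothetical model: would the process at (node_rank, local_rank) run prepare_data()?"""
    if prepare_data_per_node:
        # One process per node: LOCAL_RANK 0 on every node
        return local_rank == 0
    # One process for the whole job: NODE_RANK 0 and LOCAL_RANK 0
    return node_rank == 0 and local_rank == 0


# Two nodes with two devices each, as (node_rank, local_rank) pairs
ranks = [(node, local) for node in range(2) for local in range(2)]
per_node = [r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=True)]
global_only = [r for r in ranks if calls_prepare_data(*r, prepare_data_per_node=False)]
print(per_node)     # [(0, 0), (1, 0)]
print(global_only)  # [(0, 0)]
```

Use the per-node setting when each node has its own local disk; use the global setting when all nodes share one filesystem, so the download happens only once.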

allow_zero_length_dataloader_with_multiple_devices

If True, a dataloader with zero length on a local rank is allowed when running on multiple devices. Default value is False.
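
A minimal sketch of what the flag controls, assuming one dataloader length per rank is known. The function check_local_lengths and its error message are hypothetical; the sketch does not reproduce Lightning's exact checks, warnings, or messages.

```python
from typing import Sequence


def check_local_lengths(
    lengths_per_rank: Sequence[int],
    allow_zero_length_dataloader_with_multiple_devices: bool = False,
) -> None:
    """Hypothetical check: in a multi-device run, reject any rank whose
    dataloader is empty unless zero-length local dataloaders are allowed."""
    if len(lengths_per_rank) < 2 or allow_zero_length_dataloader_with_multiple_devices:
        return
    for rank, length in enumerate(lengths_per_rank):
        if length == 0:
            raise RuntimeError(f"dataloader on rank {rank} has zero length")


check_local_lengths([5, 5])  # every rank has batches: no error
check_local_lengths([5, 0], allow_zero_length_dataloader_with_multiple_devices=True)  # allowed
try:
    check_local_lengths([5, 0])  # default False: rejected
except RuntimeError as err:
    print(err)  # dataloader on rank 1 has zero length
```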

classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, predict_dataset=None, batch_size=1, num_workers=0, **datamodule_kwargs)

Create an instance from one or more torch.utils.data.Dataset objects.

Parameters:
  • train_dataset (Union[Dataset, Iterable[Dataset], None]) – Optional dataset or iterable of datasets to be used for train_dataloader()

  • val_dataset (Union[Dataset, Iterable[Dataset], None]) – Optional dataset or iterable of datasets to be used for val_dataloader()

  • test_dataset (Union[Dataset, Iterable[Dataset], None]) – Optional dataset or iterable of datasets to be used for test_dataloader()

  • predict_dataset (Union[Dataset, Iterable[Dataset], None]) – Optional dataset or iterable of datasets to be used for predict_dataloader()

  • batch_size (int) – Batch size to use for each dataloader. Default is 1. This parameter is forwarded to __init__ if the datamodule’s __init__ signature has a parameter with that name.

  • num_workers (int) – Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Default is 0. This parameter is forwarded to __init__ if the datamodule’s __init__ signature has a parameter with that name.

  • **datamodule_kwargs (Any) – Additional parameters that get passed down to the datamodule’s __init__.
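
The forwarding rule for batch_size and num_workers can be sketched with inspect.signature: a value is passed on only if __init__ declares a parameter with that name, or accepts **kwargs. The helper forwarded_kwargs and the two example classes are hypothetical; this is an illustration of the described rule, not Lightning's code.

```python
import inspect
from typing import Any, Dict


def forwarded_kwargs(cls: type, candidates: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: keep only the candidate kwargs that cls.__init__ accepts."""
    params = inspect.signature(cls.__init__).parameters
    # A **kwargs catch-all accepts every candidate
    if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(candidates)
    return {k: v for k, v in candidates.items() if k in params and k != "self"}


class WithBatchSize:  # hypothetical datamodule that only declares batch_size
    def __init__(self, batch_size: int = 1):
        self.batch_size = batch_size


class WithoutArgs:  # hypothetical datamodule with no extra parameters
    def __init__(self):
        pass


candidates = {"batch_size": 32, "num_workers": 4}
print(forwarded_kwargs(WithBatchSize, candidates))  # {'batch_size': 32}
print(forwarded_kwargs(WithoutArgs, candidates))    # {}
```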

Return type:

LightningDataModule

load_from_checkpoint(checkpoint_path, map_location=None, hparams_file=None, **kwargs)

The primary way of loading a datamodule from a checkpoint. When Lightning saves a checkpoint, it stores the arguments passed to __init__ (as recorded by save_hyperparameters()) in the checkpoint under "datamodule_hyper_parameters".

Any arguments passed through **kwargs override the values stored in "datamodule_hyper_parameters".
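
The override rule amounts to a dictionary merge in which the caller's keyword arguments win. The sketch below assumes the checkpoint has already been loaded into a plain dict; the helper resolve_init_kwargs and the checkpoint contents are hypothetical.

```python
from typing import Any, Dict

HPARAMS_KEY = "datamodule_hyper_parameters"


def resolve_init_kwargs(checkpoint: Dict[str, Any], **overrides: Any) -> Dict[str, Any]:
    """Hypothetical helper: merge saved hyperparameters with overrides (overrides win)."""
    saved = checkpoint.get(HPARAMS_KEY, {})
    # Later entries replace earlier ones, so overrides take precedence
    return {**saved, **overrides}


# Hypothetical checkpoint contents, shown as a plain dict
checkpoint = {HPARAMS_KEY: {"batch_size": 16, "num_workers": 2, "data_dir": "data/"}}
print(resolve_init_kwargs(checkpoint, batch_size=32))
# {'batch_size': 32, 'num_workers': 2, 'data_dir': 'data/'}
```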

Parameters:
  • checkpoint_path (Union[str, Path, IO]) – Path to the checkpoint. This can also be a URL or a file-like object.

  • map_location (Union[device, str, int, Callable[[UntypedStorage, str], Optional[UntypedStorage]], Dict[Union[device, str, int], Union[device, str, int]], None]) – If your checkpoint saved a GPU model and you now load on CPUs or a different number of GPUs, use this to map to the new setup. The behaviour is the same as in torch.load().

  • hparams_file (Union[str, Path, None]) –

    Optional path to a .yaml or .csv file with hierarchical structure as in this example:

    dataloader:
        batch_size: 32
    

    You most likely won’t need this, since Lightning saves the hyperparameters to the checkpoint. However, if your checkpoint does not contain saved hyperparameters, use this argument to pass in a .yaml file with the hparams you’d like to use. These will be converted into a dict and passed into your LightningDataModule.

    If your datamodule’s hparams argument is a Namespace and the .yaml file has a hierarchical structure, you need to refactor your datamodule to treat hparams as a dict.

  • **kwargs (Any) – Any extra keyword args needed to init the datamodule. Can also be used to override saved hyperparameter values.
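
The Namespace caveat for hparams_file can be seen with argparse.Namespace: only the top level of a hierarchical mapping becomes attributes, while nested sections stay plain dicts. This sketch assumes the YAML file has already been parsed into the dict shown (the parsing step is omitted) and is meant only to illustrate why dict-style access is needed.

```python
from argparse import Namespace

# What the hierarchical hparams file in the example above becomes once parsed
hparams = {"dataloader": {"batch_size": 32}}

ns = Namespace(**hparams)
# Only the top level becomes attributes; the nested section is still a dict
print(type(ns.dataloader))  # <class 'dict'>
try:
    ns.dataloader.batch_size
except AttributeError:
    print("attribute access fails on the nested section")

# Treating hparams as a dict works at every level
print(hparams["dataloader"]["batch_size"])  # 32
```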

Return type:

Self

Returns:

LightningDataModule instance with the loaded hyperparameters and datamodule state (if available).

Note

load_from_checkpoint is a class method. Call it on your LightningDataModule class, not on an instance of it; calling it on an instance raises a TypeError.
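
One way to get the behaviour described in the note is a classmethod-like descriptor that binds the class when accessed on the class and refuses calls made through an instance. This is a sketch in the same spirit, not Lightning's implementation; restricted_classmethod and ToyDataModule are hypothetical names.

```python
import functools
from typing import Any, Callable


class restricted_classmethod:
    """Hypothetical descriptor: like classmethod, but calls through an instance raise TypeError."""

    def __init__(self, func: Callable[..., Any]) -> None:
        self.func = func

    def __get__(self, instance: Any, owner: type) -> Callable[..., Any]:
        if instance is not None:
            name = self.func.__name__

            def fail(*args: Any, **kwargs: Any) -> Any:
                raise TypeError(f"{name} must be called on the class, not on an instance")

            return fail
        # Bind the owning class as the first argument, as classmethod does
        return functools.partial(self.func, owner)


class ToyDataModule:  # hypothetical stand-in for a LightningDataModule subclass
    @restricted_classmethod
    def load_from_checkpoint(cls, path: str) -> "ToyDataModule":
        return cls()


dm = ToyDataModule.load_from_checkpoint("path/to/checkpoint.ckpt")  # works
try:
    dm.load_from_checkpoint("path/to/checkpoint.ckpt")  # called on an instance
except TypeError as err:
    print(err)
```

Raising on the call rather than on attribute access keeps introspection such as hasattr() working.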

Example:

# load the datamodule from a checkpoint without a map_location
datamodule = MyLightningDataModule.load_from_checkpoint('path/to/checkpoint.ckpt')

# or load the checkpoint and the hyperparameters from separate files
datamodule = MyLightningDataModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    hparams_file='/path/to/hparams_file.yaml'
)

# override some of the saved hyperparameters with new values
datamodule = MyLightningDataModule.load_from_checkpoint(
    'path/to/checkpoint.ckpt',
    batch_size=32,
    num_workers=10,
)

load_state_dict(state_dict)

Called when loading a checkpoint. Implement this to reload the datamodule state from the given state_dict.

Parameters:

state_dict (Dict[str, Any]) – the datamodule state returned by state_dict.

Return type:

None

state_dict()

Called when saving a checkpoint. Implement this to generate the datamodule state to be saved.

Return type:

Dict[str, Any]

Returns:

A dictionary containing datamodule state.
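
state_dict() and load_state_dict() are meant to round-trip: whatever the first returns when a checkpoint is saved, the second receives when it is loaded. The sketch below shows that contract with a plain class so it runs without Lightning; ToyDataModule and its current_fold attribute are hypothetical, and a real implementation would define the same two methods on a LightningDataModule subclass.

```python
from typing import Any, Dict


class ToyDataModule:  # hypothetical; a real one would subclass L.LightningDataModule
    def __init__(self) -> None:
        self.current_fold = 0  # example of state worth restoring when resuming

    def state_dict(self) -> Dict[str, Any]:
        # Called when saving a checkpoint: return everything needed to resume
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when loading a checkpoint: restore from what state_dict() returned
        self.current_fold = state_dict["current_fold"]


saved = ToyDataModule()
saved.current_fold = 3
checkpoint_entry = saved.state_dict()

restored = ToyDataModule()
restored.load_state_dict(checkpoint_entry)
print(restored.current_fold)  # 3
```

Keeping the returned values to simple, picklable types makes the state safe to store alongside the rest of the checkpoint.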