LightningDataModule

class pytorch_lightning.core.LightningDataModule(train_transforms=None, val_transforms=None, test_transforms=None, dims=None)[source]

Bases: pytorch_lightning.core.hooks.CheckpointHooks, pytorch_lightning.core.hooks.DataHooks, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin

A DataModule standardizes the training, val, test splits, data preparation and transforms. The main advantage is consistent data splits, data preparation and transforms across models.

Example:

from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader, Dataset


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        ...

    def setup(self, stage):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        ...

    def train_dataloader(self):
        train_split = Dataset(...)  # placeholder: your Dataset subclass
        return DataLoader(train_split)

    def val_dataloader(self):
        val_split = Dataset(...)  # placeholder: your Dataset subclass
        return DataLoader(val_split)

    def test_dataloader(self):
        test_split = Dataset(...)  # placeholder: your Dataset subclass
        return DataLoader(test_split)

    def teardown(self, stage):
        # clean up after fit or test
        # called on every process in DDP
        ...
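
Once defined, a datamodule is passed straight to a Trainer. A minimal sketch, assuming a LightningModule subclass named MyModel exists:

import pytorch_lightning as pl

model = MyModel()    # hypothetical LightningModule subclass
dm = MyDataModule()

trainer = pl.Trainer(max_epochs=1)
trainer.fit(model, datamodule=dm)   # runs prepare_data, setup, then the train/val loaders
trainer.test(model, datamodule=dm)  # runs setup and test_dataloader
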
classmethod add_argparse_args(parent_parser, **kwargs)[source]

Extends an existing argparse parser with the default LightningDataModule arguments.

Return type

ArgumentParser

classmethod from_argparse_args(args, **kwargs)[source]

Create an instance from CLI arguments.

Parameters
  • args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the LightningDataModule.

  • **kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.

Example:

from argparse import ArgumentParser

parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()
module = LightningDataModule.from_argparse_args(args)
classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, batch_size=1, num_workers=0)[source]

Create an instance from torch.utils.data.Dataset.

Parameters
  • train_dataset (Union[Dataset, Sequence[Dataset], Mapping[str, Dataset], None]) – (optional) Dataset to be used for train_dataloader()

  • val_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for val_dataloader()

  • test_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for test_dataloader()

  • batch_size (int) – Batch size to use for each dataloader. Default is 1.

  • num_workers (int) – Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Default is 0.
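
Example, a minimal sketch using toy TensorDataset inputs (any torch.utils.data.Dataset works):

import torch
from torch.utils.data import TensorDataset

from pytorch_lightning import LightningDataModule

# toy classification data standing in for real datasets
train_ds = TensorDataset(torch.randn(100, 3), torch.randint(0, 2, (100,)))
val_ds = TensorDataset(torch.randn(20, 3), torch.randint(0, 2, (20,)))

dm = LightningDataModule.from_datasets(
    train_dataset=train_ds,
    val_dataset=val_ds,
    batch_size=32,
    num_workers=0,
)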

classmethod get_init_arguments_and_types()[source]

Scans the DataModule signature and returns argument names, types and default values.

Returns

A list of tuples of (argument name, tuple of argument types, argument default value).

Return type

List[Tuple[str, Tuple, Any]]
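
Example, a short sketch of inspecting the constructor this way (the exact entries depend on the subclass):

from pytorch_lightning import LightningDataModule

for name, types, default in LightningDataModule.get_init_arguments_and_types():
    print(name, types, default)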

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload the datamodule state given the datamodule state_dict.

Parameters

state_dict (Dict[str, Any]) – the datamodule state returned by state_dict.

Return type

None

size(dim=None)[source]

Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.

Deprecated since version v1.5: Will be removed in v1.7.0.

Return type

Union[Tuple, List[Tuple]]

state_dict()[source]

Called when saving a checkpoint. Implement to generate and save the datamodule state.

Return type

Dict[str, Any]

Returns

A dictionary containing datamodule state.
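
Together with load_state_dict above, these hooks let a datamodule round-trip custom state through checkpoints. A minimal sketch; the current_fold field is a hypothetical example, not part of the API:

from typing import Any, Dict

from pytorch_lightning import LightningDataModule


class FoldAwareDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()
        self.current_fold = 0  # hypothetical state worth checkpointing

    def state_dict(self) -> Dict[str, Any]:
        # saved into the checkpoint alongside the model weights
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # restored when the checkpoint is loaded
        self.current_fold = state_dict["current_fold"]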

property dims

A tuple describing the shape of your data. Additional functionality is exposed through size().

Deprecated since version v1.5: Will be removed in v1.7.0.

property test_transforms

Optional transforms (or collection of transforms) you can apply to the test dataset.

Deprecated since version v1.5: Will be removed in v1.7.0.

property train_transforms

Optional transforms (or collection of transforms) you can apply to the train dataset.

Deprecated since version v1.5: Will be removed in v1.7.0.

property val_transforms

Optional transforms (or collection of transforms) you can apply to the validation dataset.

Deprecated since version v1.5: Will be removed in v1.7.0.