LightningDataModule
- class pytorch_lightning.core.LightningDataModule(train_transforms=None, val_transforms=None, test_transforms=None, dims=None)
Bases: pytorch_lightning.core.hooks.CheckpointHooks, pytorch_lightning.core.hooks.DataHooks, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. Its main advantage is that data splits, preparation, and transforms stay consistent across models.
Example:
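```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        # placeholder: 100 random samples stand in for a real Dataset
        full = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
        self.train_split, self.val_split, self.test_split = random_split(full, [80, 10, 10])

    def train_dataloader(self):
        return DataLoader(self.train_split)

    def val_dataloader(self):
        return DataLoader(self.val_split)

    def test_dataloader(self):
        return DataLoader(self.test_split)

    def teardown(self, stage=None):
        # clean up after fit or test
        # called on every process in DDP
        pass
```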
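Consistent splits across models depend on the split being reproducible. The following is a minimal stdlib sketch, assuming a fixed seed is acceptable for your data: the hypothetical helper `split_indices` shuffles indices with a seeded `random.Random`, so every run, and therefore every model, gets the same train/val/test partition. Inside a real setup() you would do the equivalent with torch.utils.data.random_split and a seeded torch.Generator.

```python
import random


def split_indices(n, fractions=(0.8, 0.1, 0.1), seed=42):
    """Hypothetical helper: deterministically split range(n) into train/val/test index lists."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError("fractions must sum to 1")
    indices = list(range(n))
    # A dedicated Random instance keeps the shuffle independent of global random state.
    random.Random(seed).shuffle(indices)
    n_train = int(n * fractions[0])
    n_val = int(n * fractions[1])
    train = indices[:n_train]
    val = indices[n_train:n_train + n_val]
    test = indices[n_train + n_val:]  # the remainder goes to test
    return train, val, test


if __name__ == "__main__":
    first = split_indices(100)
    second = split_indices(100)
    print([len(part) for part in first])  # [80, 10, 10]
    print(first == second)  # True: same seed, same split
```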
- classmethod add_argparse_args(parent_parser, **kwargs)
Extends an existing argparse parser with the default LightningDataModule attributes.
- Return type
ArgumentParser
- classmethod from_argparse_args(args, **kwargs)
Create an instance from CLI arguments.
- Parameters
args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the LightningDataModule.
**kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.
Example:
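```python
from argparse import ArgumentParser
from pytorch_lightning import LightningDataModule

parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()  # args must be parsed before it is passed on
module = LightningDataModule.from_argparse_args(args)
```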
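To see roughly how the two class methods fit together, here is a simplified stdlib sketch; it is not Lightning's implementation. The hypothetical helpers `add_init_args` and `from_args` read a class's `__init__` signature with `inspect.signature`, register one `--name` option per parameter, and later pass back only the parsed values the constructor accepts, letting keyword arguments override them. `ToyDataModule` is a stand-in class, and falling back to `str` for unannotated parameters is an assumption of this sketch.

```python
import inspect
from argparse import ArgumentParser


class ToyDataModule:
    """Hypothetical stand-in for a LightningDataModule subclass."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32):
        self.data_dir = data_dir
        self.batch_size = batch_size


def add_init_args(cls, parent_parser):
    """Add one --option per __init__ parameter (simplified: annotations used as converters)."""
    parser = ArgumentParser(parents=[parent_parser], add_help=False)
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self":
            continue
        empty = inspect.Parameter.empty
        arg_type = param.annotation if param.annotation is not empty else str  # assumption
        default = None if param.default is empty else param.default
        parser.add_argument(f"--{name}", type=arg_type, default=default)
    return parser


def from_args(cls, args, **kwargs):
    """Build cls from a Namespace, keeping only known __init__ arguments; kwargs win."""
    valid = set(inspect.signature(cls.__init__).parameters) - {"self"}
    params = {k: v for k, v in vars(args).items() if k in valid}
    params.update(kwargs)
    return cls(**params)


if __name__ == "__main__":
    parser = ArgumentParser(add_help=False)
    parser = add_init_args(ToyDataModule, parser)
    args = parser.parse_args(["--batch_size", "64"])
    dm = from_args(ToyDataModule, args, data_dir="/tmp/data")
    print(dm.batch_size, dm.data_dir)  # 64 /tmp/data
```

Unknown options from the parent parser (for example trainer flags) are filtered out in `from_args`, which mirrors the documented "only known arguments will be parsed" behavior.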
- classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, batch_size=1, num_workers=0)
Create an instance from one or more torch.utils.data.Dataset objects.
- Parameters
train_dataset (Union[Dataset, Sequence[Dataset], Mapping[str, Dataset], None]) – (optional) Dataset, sequence of Datasets, or mapping of names to Datasets to be used for train_dataloader().
val_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Datasets to be used for val_dataloader().
test_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Datasets to be used for test_dataloader().
batch_size (int) – Batch size to use for each dataloader. Default is 1.
num_workers (int) – Number of subprocesses to use for data loading. Default is 0, which means the data is loaded in the main process. A typical upper bound is the number of CPUs available.
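The train_dataset type allows a single dataset, a sequence of datasets, or a mapping of names to datasets. A plausible reading is that the resulting training loader mirrors that shape. The stdlib sketch below illustrates this under that assumption; it is not Lightning's implementation. The hypothetical `make_loaders` wraps each dataset with a toy `batches` function (standing in for torch.utils.data.DataLoader with the given batch_size) and returns a single loader, a list, or a dict accordingly.

```python
from collections.abc import Mapping, Sequence


class ToyDataset:
    """Hypothetical map-style dataset: supports len() and integer indexing."""

    def __init__(self, items):
        self._items = list(items)

    def __len__(self):
        return len(self._items)

    def __getitem__(self, idx):
        return self._items[idx]


def batches(dataset, batch_size=1):
    """Toy stand-in for a DataLoader: a list of batches (the last one may be smaller)."""
    return [
        [dataset[i] for i in range(start, min(start + batch_size, len(dataset)))]
        for start in range(0, len(dataset), batch_size)
    ]


def make_loaders(datasets, batch_size=1):
    """Mirror the container shape: mapping -> dict, sequence -> list, single dataset -> one loader."""
    if isinstance(datasets, Mapping):
        return {name: batches(ds, batch_size) for name, ds in datasets.items()}
    if isinstance(datasets, Sequence):
        return [batches(ds, batch_size) for ds in datasets]
    return batches(datasets, batch_size)


if __name__ == "__main__":
    a = ToyDataset(range(5))
    b = ToyDataset("xyz")
    print(make_loaders(a, batch_size=2))  # [[0, 1], [2, 3], [4]]
    print(make_loaders({"a": a, "b": b}, batch_size=3))
```

ToyDataset deliberately does not subclass Sequence, so a single dataset is not mistaken for a collection of datasets.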
- classmethod get_init_arguments_and_types()
Scans the DataModule signature and returns argument names, types and default values.
- Returns
(argument name, set of argument types, argument default value).
- Return type
List of tuples with 3 values
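As a rough illustration of the documented return shape, the hypothetical `init_arguments_and_types` below uses `inspect.signature` and `typing.get_args` to build (name, set of types, default) tuples for a toy class. It is a simplified sketch, not Lightning's implementation; for example, giving unannotated parameters an empty type set is an assumption. It needs Python 3.8+ for `typing.get_origin`/`get_args`.

```python
import inspect
from typing import Optional, Union, get_args, get_origin


class ToyDataModule:
    """Hypothetical DataModule-like class used only for this illustration."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32,
                 val_split: Optional[float] = None, seed=0):
        pass


def init_arguments_and_types(cls):
    """Return [(name, set_of_types, default), ...] for cls.__init__, skipping self."""
    result = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self":
            continue
        annotation = param.annotation
        if annotation is inspect.Parameter.empty:
            types = set()  # assumption: unannotated parameters report no types
        elif get_origin(annotation) is Union:
            types = set(get_args(annotation))  # Optional[float] -> {float, NoneType}
        else:
            types = {annotation}
        result.append((name, types, param.default))
    return result


if __name__ == "__main__":
    for row in init_arguments_and_types(ToyDataModule):
        print(row)
```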
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement this to reload the datamodule's state from the given datamodule state_dict.
- size(dim=None)
Returns the dimensions of each input, either as a tuple or as a list of tuples. You can index the result just as you would a torch tensor.
Deprecated since version v1.5: Will be removed in v1.7.0.
- state_dict()
Called when saving a checkpoint. Implement this to generate and save the datamodule's state.
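The two checkpoint hooks are meant to be overridden as a pair: state_dict() returns whatever the datamodule needs to resume, and load_state_dict() restores it. Below is a minimal sketch of that round trip using a plain hypothetical class (not a LightningDataModule subclass, so it runs without Lightning) that tracks which cross-validation fold it is serving; the `current_fold` attribute is an illustrative assumption.

```python
import copy


class FoldTrackingDataModule:
    """Hypothetical datamodule that remembers which fold it is serving."""

    def __init__(self, num_folds=5):
        self.num_folds = num_folds
        self.current_fold = 0

    def next_fold(self):
        self.current_fold = (self.current_fold + 1) % self.num_folds

    def state_dict(self):
        # In Lightning, this hook runs when a checkpoint is saved:
        # return plain, serializable values.
        return {"current_fold": self.current_fold}

    def load_state_dict(self, state_dict):
        # In Lightning, this hook runs when a checkpoint is loaded:
        # restore exactly what state_dict() produced.
        self.current_fold = state_dict["current_fold"]


if __name__ == "__main__":
    dm = FoldTrackingDataModule()
    dm.next_fold()
    dm.next_fold()
    checkpoint = copy.deepcopy(dm.state_dict())  # simulate writing to disk
    restored = FoldTrackingDataModule()
    restored.load_state_dict(checkpoint)
    print(restored.current_fold)  # 2
```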
- property dims
A tuple describing the shape of your data. Extra functionality is exposed in size().
Deprecated since version v1.5: Will be removed in v1.7.0.
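To illustrate the documented relationship between dims and size(dim): dims holds either one shape tuple or a list of shape tuples, and size(dim) indexes into it. The tiny stdlib sketch below uses a hypothetical `ShapeInfo` class and assumes size() with no argument returns dims unchanged. Since both members are deprecated, code targeting v1.7 and later has to track shapes itself, which is all this sketch does.

```python
from typing import List, Optional, Tuple, Union

Shape = Tuple[int, ...]


class ShapeInfo:
    """Hypothetical holder mirroring the deprecated dims/size() pair."""

    def __init__(self, dims: Union[Shape, List[Shape]]):
        self.dims = dims

    def size(self, dim: Optional[int] = None):
        # With no argument, return everything; otherwise index like a tuple or list.
        return self.dims if dim is None else self.dims[dim]


if __name__ == "__main__":
    single = ShapeInfo((1, 28, 28))           # one input, e.g. a grayscale image
    multi = ShapeInfo([(1, 28, 28), (10,)])   # two inputs
    print(single.size(), single.size(0))      # (1, 28, 28) 1
    print(multi.size(1))                      # (10,)
```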
- property test_transforms
Optional transforms (or collection of transforms) you can apply to the test dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
- property train_transforms
Optional transforms (or collection of transforms) you can apply to the training dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
- property val_transforms
Optional transforms (or collection of transforms) you can apply to the validation dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
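Because the transform properties are deprecated, one way to keep per-split transforms is to apply them inside the dataset itself. This is a minimal stdlib sketch with a hypothetical wrapper, `TransformedDataset`, that applies an optional callable to each item on access, so the training, validation, and test splits can each get their own transform (or none).

```python
from typing import Callable, Optional, Sequence


class TransformedDataset:
    """Hypothetical wrapper that applies an optional transform when an item is read."""

    def __init__(self, data: Sequence, transform: Optional[Callable] = None):
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        # Leave the item untouched when no transform was given.
        return self.transform(item) if self.transform is not None else item


if __name__ == "__main__":
    raw = [1, 2, 3]
    train = TransformedDataset(raw, transform=lambda x: x * 10)  # e.g. augmentation
    val = TransformedDataset(raw)  # no transform
    print([train[i] for i in range(len(train))])  # [10, 20, 30]
    print([val[i] for i in range(len(val))])      # [1, 2, 3]
```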