datamodule
Classes
LightningDataModule: A DataModule standardizes the training, validation, and test splits, data preparation, and transforms.
LightningDataModule for loading DataLoaders with ease.
- class pytorch_lightning.core.datamodule.LightningDataModule(*args: Any, **kwargs: Any)
Bases: pytorch_lightning.core.hooks.CheckpointHooks, pytorch_lightning.core.hooks.DataHooks, pytorch_lightning.core.mixins.hparams_mixin.HyperparametersMixin
A DataModule standardizes the training, validation, and test splits, data preparation, and transforms. The main advantage is that data splits, data preparation, and transforms stay consistent across models.
Example:
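```python
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from pytorch_lightning import LightningDataModule


class MyDataModule(LightningDataModule):
    def __init__(self):
        super().__init__()

    def prepare_data(self):
        # download, split, etc...
        # only called on 1 GPU/TPU in distributed
        pass

    def setup(self, stage=None):
        # make assignments here (val/train/test split)
        # called on every process in DDP
        # placeholder random data stands in for a real dataset
        full = TensorDataset(torch.randn(100, 8), torch.randint(0, 2, (100,)))
        self.train_split, self.val_split, self.test_split = random_split(full, [80, 10, 10])

    def train_dataloader(self):
        return DataLoader(self.train_split)

    def val_dataloader(self):
        return DataLoader(self.val_split)

    def test_dataloader(self):
        return DataLoader(self.test_split)

    def teardown(self, stage=None):
        # clean up after fit or test
        # called on every process in DDP
        pass
```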
A DataModule implements 6 key methods:
- prepare_data: work to do on only one GPU/TPU, not on every GPU/TPU, in distributed mode.
- setup: work to do on every accelerator in distributed mode.
- train_dataloader: the training dataloader.
- val_dataloader: the validation dataloader(s).
- test_dataloader: the test dataloader(s).
- teardown: work to do on every accelerator in distributed mode when finished.
This lets you share a full dataset without having to explain how to download, split, transform, and process the data.
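The division of labor among these hooks can be illustrated with a minimal, library-free sketch. `RecordingDataModule` and `run_fit` are hypothetical names; plain lists stand in for DataLoaders, and the driver is not Lightning's Trainer, whose exact call order (for example, when validation loaders are requested) may differ. It only shows that prepare_data runs on a single process while setup and teardown run on every process.

```python
from typing import List, Optional


class RecordingDataModule:
    """Hypothetical stand-in that mimics the DataModule hooks and records each call."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_data(self) -> None:
        self.calls.append("prepare_data")

    def setup(self, stage: Optional[str] = None) -> None:
        self.calls.append(f"setup:{stage}")

    def train_dataloader(self):
        self.calls.append("train_dataloader")
        return [[0, 1], [2, 3]]  # stand-in for a DataLoader: an iterable of batches

    def val_dataloader(self):
        self.calls.append("val_dataloader")
        return [[4, 5]]

    def teardown(self, stage: Optional[str] = None) -> None:
        self.calls.append(f"teardown:{stage}")


def run_fit(dm: RecordingDataModule, is_rank_zero: bool = True) -> int:
    """Hypothetical simplified driver (not the real Trainer); returns the number of samples seen."""
    if is_rank_zero:
        dm.prepare_data()  # only one process downloads / splits
    dm.setup(stage="fit")  # every process assigns its splits
    seen = sum(len(batch) for batch in dm.train_dataloader())
    seen += sum(len(batch) for batch in dm.val_dataloader())
    dm.teardown(stage="fit")  # every process cleans up
    return seen
```

Calling `run_fit(dm, is_rank_zero=False)` mimics a non-zero rank: setup and teardown still run, but prepare_data is skipped.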
- classmethod add_argparse_args(parent_parser, **kwargs)
Extends an existing argparse parser with the default LightningDataModule attributes (its __init__ arguments).
- Return type
ArgumentParser
- classmethod from_argparse_args(args, **kwargs)
Create an instance from CLI arguments.
- Parameters
args (Union[Namespace, ArgumentParser]) – The parser or namespace to take arguments from. Only known arguments will be parsed and passed to the LightningDataModule.
**kwargs – Additional keyword arguments that may override ones in the parser or namespace. These must be valid DataModule arguments.
Example:
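```python
from argparse import ArgumentParser

from pytorch_lightning import LightningDataModule

parser = ArgumentParser(add_help=False)
parser = LightningDataModule.add_argparse_args(parser)
args = parser.parse_args()  # parse the CLI so that `args` exists
module = LightningDataModule.from_argparse_args(args)
```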
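Conceptually, add_argparse_args and from_argparse_args round-trip a DataModule's __init__ arguments through argparse. The sketch below assumes that behavior and uses only the standard library; `ToyDataModule`, `add_init_args`, and `from_namespace` are hypothetical and are not Lightning's implementation (for instance, this version accepts only a Namespace, not an ArgumentParser).

```python
import inspect
from argparse import ArgumentParser, Namespace


class ToyDataModule:
    """Hypothetical DataModule-like class; only its __init__ signature matters here."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32) -> None:
        self.data_dir = data_dir
        self.batch_size = batch_size


def add_init_args(parser: ArgumentParser, cls: type) -> ArgumentParser:
    """Hypothetical helper: add one --flag per __init__ parameter that has a default."""
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self" or param.default is inspect.Parameter.empty:
            continue
        parser.add_argument(f"--{name}", type=type(param.default), default=param.default)
    return parser


def from_namespace(cls: type, args: Namespace, **kwargs):
    """Hypothetical helper: pass only arguments cls.__init__ accepts, then apply overrides."""
    valid = set(inspect.signature(cls.__init__).parameters) - {"self"}
    params = {k: v for k, v in vars(args).items() if k in valid}
    params.update(kwargs)  # explicit keyword arguments override parsed values
    return cls(**params)


parser = ArgumentParser(add_help=False)
parser = add_init_args(parser, ToyDataModule)
parser.add_argument("--max_epochs", type=int, default=10)  # unrelated flag, e.g. for a trainer
args = parser.parse_args(["--batch_size", "64", "--max_epochs", "3"])

dm = from_namespace(ToyDataModule, args, data_dir="/tmp/data")
```

Here `--max_epochs` is dropped because `ToyDataModule.__init__` does not accept it, and `data_dir` passed as a keyword overrides the parsed default.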
- classmethod from_datasets(train_dataset=None, val_dataset=None, test_dataset=None, batch_size=1, num_workers=0)
Create an instance from torch.utils.data.Dataset objects.
- Parameters
train_dataset (Union[Dataset, Sequence[Dataset], Mapping[str, Dataset], None]) – (optional) Dataset, list of Dataset, or mapping of names to Dataset to be used for train_dataloader().
val_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for val_dataloader().
test_dataset (Union[Dataset, Sequence[Dataset], None]) – (optional) Dataset or list of Dataset to be used for test_dataloader().
batch_size (int) – Batch size to use for each dataloader. Default is 1.
num_workers (int) – Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process. Default is 0; a common upper bound is the number of CPUs available.
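The three input shapes accepted by from_datasets (a single Dataset, a sequence of them, or, for training only, a mapping of names to datasets) suggest a shape-preserving dispatch. Below is a standard-library sketch, assuming each dataset becomes one loader; `ToyDataset`, `batches` (a stand-in for DataLoader without shuffling or workers), and `make_loaders` are hypothetical helpers, not Lightning's code.

```python
from collections.abc import Mapping, Sequence
from typing import Any, Dict, List, Union


class ToyDataset:
    """Hypothetical map-style dataset: __len__ and __getitem__, like torch's Dataset."""

    def __init__(self, items: List[Any]) -> None:
        self.items = items

    def __len__(self) -> int:
        return len(self.items)

    def __getitem__(self, idx: int) -> Any:
        return self.items[idx]


def batches(ds: ToyDataset, batch_size: int) -> List[List[Any]]:
    """Stand-in for DataLoader(ds, batch_size=batch_size): sequential, no workers."""
    return [
        [ds[i] for i in range(start, min(start + batch_size, len(ds)))]
        for start in range(0, len(ds), batch_size)
    ]


def make_loaders(dataset, batch_size: int = 1) -> Union[List[Any], Dict[str, Any], None]:
    """Map each accepted input shape to loaders of the same shape (hypothetical helper)."""
    if dataset is None:
        return None
    if isinstance(dataset, Mapping):  # name -> dataset becomes name -> loader
        return {name: batches(ds, batch_size) for name, ds in dataset.items()}
    if isinstance(dataset, Sequence):  # list of datasets becomes list of loaders
        return [batches(ds, batch_size) for ds in dataset]
    return batches(dataset, batch_size)  # single dataset becomes a single loader
```

`ToyDataset` deliberately does not subclass `Sequence`, so a single dataset is not mistaken for a list of datasets.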
- classmethod get_init_arguments_and_types()
Scans the DataModule signature and returns argument names, types and default values.
- Returns
(argument name, set with argument types, argument default value).
- Return type
List with tuples of 3 values
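The (name, types, default) tuples described above can be reproduced with `inspect.signature`. This is a minimal sketch, not Lightning's code: `ToyDataModule` and `init_arguments_and_types` are hypothetical, and the types are collected in a tuple, with Optional/Union annotations expanded through their `__args__`.

```python
import inspect
from typing import Any, List, Optional, Tuple


class ToyDataModule:
    """Hypothetical DataModule-like class used only for this sketch."""

    def __init__(self, data_dir: str = "./data", batch_size: int = 32,
                 num_workers: Optional[int] = None) -> None:
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers


def init_arguments_and_types(cls: type) -> List[Tuple[str, Tuple[Any, ...], Any]]:
    """Return (name, types, default) for each __init__ parameter except self."""
    result = []
    for name, param in inspect.signature(cls.__init__).parameters.items():
        if name == "self":
            continue
        annotation = param.annotation
        # Optional[X] / Union[...] expose their member types via __args__
        types = tuple(getattr(annotation, "__args__", (annotation,)))
        result.append((name, types, param.default))
    return result
```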
- size(dim=None)
Return the dimension of each input either as a tuple or list of tuples. You can index this just as you would with a torch tensor.
Deprecated since version v1.5: Will be removed in v1.7.0.
- property dims
A tuple describing the shape of your data. Extra functionality is exposed in size.
Deprecated since version v1.5: Will be removed in v1.7.0.
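size is described as indexable like a torch tensor: `tensor.size()` returns the full shape and `tensor.size(dim)` returns one entry. A minimal sketch of that reading, assuming dims holds a single shape tuple; `ShapeInfo` is hypothetical, and both dims and size are deprecated.

```python
from typing import List, Optional, Tuple, Union

Dims = Union[Tuple[int, ...], List[Tuple[int, ...]]]


class ShapeInfo:
    """Hypothetical holder mimicking the deprecated dims/size pair."""

    def __init__(self, dims: Dims) -> None:
        self.dims = dims

    def size(self, dim: Optional[int] = None):
        # No argument: return the whole shape; otherwise index into it.
        if dim is None:
            return self.dims
        return self.dims[dim]
```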
- property has_prepared_data: bool
Return a bool indicating whether datamodule.prepare_data() has been called.
- Returns
True if datamodule.prepare_data() has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_setup_fit: bool
Return a bool indicating whether datamodule.setup(stage='fit') has been called.
- Returns
True if datamodule.setup(stage='fit') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_setup_predict: bool
Return a bool indicating whether datamodule.setup(stage='predict') has been called.
- Returns
True if datamodule.setup(stage='predict') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_setup_test: bool
Return a bool indicating whether datamodule.setup(stage='test') has been called.
- Returns
True if datamodule.setup(stage='test') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_setup_validate: bool
Return a bool indicating whether datamodule.setup(stage='validate') has been called.
- Returns
True if datamodule.setup(stage='validate') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_teardown_fit: bool
Return a bool indicating whether datamodule.teardown(stage='fit') has been called.
- Returns
True if datamodule.teardown(stage='fit') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_teardown_predict: bool
Return a bool indicating whether datamodule.teardown(stage='predict') has been called.
- Returns
True if datamodule.teardown(stage='predict') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_teardown_test: bool
Return a bool indicating whether datamodule.teardown(stage='test') has been called.
- Returns
True if datamodule.teardown(stage='test') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
- property has_teardown_validate: bool
Return a bool indicating whether datamodule.teardown(stage='validate') has been called.
- Returns
True if datamodule.teardown(stage='validate') has been called. False by default.
- Return type
bool
Deprecated since version v1.4: Will be removed in v1.6.0.
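Each has_* property above is a boolean flag that flips once the corresponding hook has run for a stage. A sketch of that bookkeeping with hypothetical names (`FlagTrackingModule`, `ALL_STAGES`); treating stage=None as "every stage" is an assumption of this sketch, and only a few of the properties are reproduced since the rest follow the same pattern.

```python
from typing import Optional, Set

# Hypothetical: the stage names used by the has_* properties above.
ALL_STAGES = frozenset({"fit", "validate", "test", "predict"})


class FlagTrackingModule:
    """Hypothetical sketch: remember which stages setup/teardown were called for."""

    def __init__(self) -> None:
        self._prepared = False
        self._setup_stages: Set[str] = set()
        self._teardown_stages: Set[str] = set()

    def prepare_data(self) -> None:
        self._prepared = True

    def setup(self, stage: Optional[str] = None) -> None:
        # Assumption for this sketch: stage=None counts as every stage.
        self._setup_stages |= {stage} if stage is not None else ALL_STAGES

    def teardown(self, stage: Optional[str] = None) -> None:
        self._teardown_stages |= {stage} if stage is not None else ALL_STAGES

    @property
    def has_prepared_data(self) -> bool:
        return self._prepared

    @property
    def has_setup_fit(self) -> bool:
        return "fit" in self._setup_stages

    @property
    def has_setup_test(self) -> bool:
        return "test" in self._setup_stages

    @property
    def has_teardown_test(self) -> bool:
        return "test" in self._teardown_stages
```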
- property test_transforms
Optional transforms (or collection of transforms) you can apply to the test dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
- property train_transforms
Optional transforms (or collection of transforms) you can apply to the training dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
- property val_transforms
Optional transforms (or collection of transforms) you can apply to the validation dataset.
Deprecated since version v1.5: Will be removed in v1.7.0.
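Independently of these deprecated properties, a transform (or a composed collection of transforms) can be applied per sample inside a dataset. A minimal sketch; `TransformedDataset` and `compose` are hypothetical helpers, not Lightning or torchvision APIs.

```python
from typing import Any, Callable, Optional, Sequence


class TransformedDataset:
    """Hypothetical wrapper: apply an optional per-sample transform when an item is fetched."""

    def __init__(self, data: Sequence[Any], transform: Optional[Callable[[Any], Any]] = None) -> None:
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int) -> Any:
        sample = self.data[idx]
        return self.transform(sample) if self.transform is not None else sample


def compose(*fns: Callable[[Any], Any]) -> Callable[[Any], Any]:
    """Chain transforms left to right, i.e. a 'collection of transforms'."""
    def apply(x: Any) -> Any:
        for fn in fns:
            x = fn(x)
        return x
    return apply
```

With `compose(lambda x: x + 1, lambda x: x * 10)`, a sample of 1 becomes (1 + 1) * 10 = 20; a dataset built without a transform returns its samples unchanged.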