trainer
Trainer to automate the training.
Classes
Trainer – Customize every aspect of training via flags.
- class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0.0, gradient_clip_algorithm='norm', process_position=0, num_nodes=1, num_processes=1, devices=None, gpus=None, auto_select_gpus=False, tpu_cores=None, ipus=None, log_gpu_memory=None, progress_bar_refresh_rate=None, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=None, min_epochs=None, max_steps=None, min_steps=None, max_time=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, limit_predict_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=50, accelerator=None, sync_batchnorm=False, precision=32, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_n_epochs=0, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, plugins=None, amp_backend='native', amp_level='O2', distributed_backend=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', stochastic_weight_avg=False)[source]
Bases: pytorch_lightning.trainer.properties.TrainerProperties, pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.deprecated_api.DeprecatedTrainerAttributes
Customize every aspect of training via flags
- Parameters
- accelerator (Union[str, Accelerator, None]) – Previously known as distributed_backend (dp, ddp, ddp2, etc.). Can also take in an accelerator object for custom hardware.
- accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates grads every k batches or as set up in the dict.
- amp_backend (str) – The mixed precision backend to use ("native" or "apex").
- amp_level (str) – The optimization level to use (O1, O2, etc.).
- auto_lr_find (Union[bool, str]) – If set to True, trainer.tune() will run a learning rate finder, trying to optimize the initial learning rate for faster convergence. trainer.tune() will set the suggested learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string with the key name instead of True.
- auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either 'power', which estimates the batch size through a power search, or 'binsearch', which estimates it through a binary search.
- auto_select_gpus (bool) – If enabled and gpus is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them.
- callbacks (Union[List[Callback], Callback, None]) – Add a callback or list of callbacks.
- checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.
- check_val_every_n_epoch (int) – Check val every n train epochs.
- default_root_dir (Optional[str]) – Default path for logs and weights when no logger/ckpt_callback passed. Default: os.getcwd(). Can be remote file paths such as s3://mybucket/path or 'hdfs://path/'.
- deterministic (bool) – If True, enables cudnn.deterministic.
- devices (Union[int, str, List[int], None]) – Will be mapped to either gpus, tpu_cores, num_processes or ipus, based on the accelerator type.
- distributed_backend (Optional[str]) – Deprecated. Please use accelerator.
- fast_dev_run (Union[int, bool]) – Runs n batch(es) of train, val and test if set to n (int), or 1 batch if set to True, to find any bugs (i.e. a sort of unit test).
- flush_logs_every_n_steps (int) – How often to flush logs to disk (defaults to every 100 steps).
- gpus (Union[int, str, List[int], None]) – Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node.
- gradient_clip_algorithm (str) – 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'.
- limit_train_batches (Union[int, float]) – How much of the training dataset to check (float = fraction, int = num_batches).
- limit_val_batches (Union[int, float]) – How much of the validation dataset to check (float = fraction, int = num_batches).
- limit_test_batches (Union[int, float]) – How much of the test dataset to check (float = fraction, int = num_batches).
- limit_predict_batches (Union[int, float]) – How much of the prediction dataset to check (float = fraction, int = num_batches).
- logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking. A True value uses the default TensorBoardLogger. False will disable logging. If multiple loggers are provided and the save_dir property of that logger is not set, local files (checkpoints, profiler traces, etc.) are saved in default_root_dir rather than in the log_dir of any of the individual loggers.
- log_gpu_memory (Optional[str]) – None, 'min_max', 'all'. Might slow performance.
- log_every_n_steps (int) – How often to log within steps (defaults to every 50 steps).
- prepare_data_per_node (bool) – If True, each LOCAL_RANK=0 will call prepare_data. Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.
- process_position (int) – Orders the progress bar when running multiple models on the same machine.
- progress_bar_refresh_rate (Optional[int]) – How often to refresh the progress bar (in steps). Value 0 disables the progress bar. Ignored when a custom progress bar is passed to callbacks. Default: None, meaning a suitable value will be chosen based on the environment (terminal, Google Colab, etc.).
- profiler (Union[BaseProfiler, str, None]) – To profile individual steps during training and assist in identifying bottlenecks.
- overfit_batches (Union[int, float]) – Overfit a fraction of the training data (float) or a set number of batches (int).
- plugins (Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str, None]) – Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
- precision (int) – Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or TPUs.
- max_epochs (Optional[int]) – Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000.
- min_epochs (Optional[int]) – Force training for at least this many epochs. Disabled by default (None). If both min_epochs and min_steps are not specified, defaults to min_epochs = 1.
- max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).
- min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).
- max_time (Union[str, timedelta, Dict[str, int], None]) – Stop training after this amount of time has passed. Disabled by default (None). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a datetime.timedelta, or as a dictionary with keys that will be passed to datetime.timedelta.
- num_nodes (int) – Number of GPU nodes for distributed training.
- num_processes (int) – Number of processes for distributed training with distributed_backend="ddp_cpu".
- num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.
- reload_dataloaders_every_n_epochs (int) – Set to a non-negative integer to reload dataloaders every n epochs. Default: 0.
- reload_dataloaders_every_epoch (bool) – Set to True to reload dataloaders every epoch. Deprecated since version v1.4: reload_dataloaders_every_epoch has been deprecated in v1.4 and will be removed in v1.6. Please use reload_dataloaders_every_n_epochs.
- replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used. By default it will add shuffle=True for the train sampler and shuffle=False for the val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.
- resume_from_checkpoint (Union[str, Path, None]) – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, start from scratch. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch.
- sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world.
- terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch if any of the parameters or the loss are NaN or +/-inf.
- tpu_cores (Union[int, str, List[int], None]) – How many TPU cores to train on (1 or 8) / Single TPU to train on [1].
- track_grad_norm (Union[int, float, str]) – -1 means no tracking. Otherwise tracks that p-norm. May be set to 'inf' for the infinity-norm.
- truncated_bptt_steps (Optional[int]) – Deprecated in v1.3, to be removed in v1.5. Please set truncated_bptt_steps on the LightningModule instead.
- val_check_interval (Union[int, float]) – How often to check the validation set. Use a float to check within a training epoch, use an int to check every n steps (batches).
- weights_summary (Optional[str]) – Prints a summary of the weights when training begins.
- weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if for whatever reason you need the checkpoints stored in a different place than the logs written in default_root_dir. Can be remote file paths such as s3://mybucket/path or 'hdfs://path/'. Defaults to default_root_dir.
- move_metrics_to_cpu (bool) – Whether to force internal logged metrics to be moved to cpu. This can save some gpu memory, but can make training slower. Use with attention.
- multiple_trainloader_mode (str) – How to loop over the datasets when there are multiple train loaders. In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets reload when reaching the minimum length of datasets.
- stochastic_weight_avg (bool) – Whether to use Stochastic Weight Averaging (SWA): https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/
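The snippet below is an illustrative construction sketch, not part of the generated reference: it simply combines a handful of the flags documented above. The specific values are arbitrary, and the gpus/precision lines assume a machine with at least one CUDA GPU.

```python
import pytorch_lightning as pl

# Minimal sketch: a Trainer configured with a few of the flags described above.
trainer = pl.Trainer(
    max_epochs=10,               # stop once 10 epochs have been reached
    gpus=1,                      # train on one GPU (assumes a CUDA device is available)
    precision=16,                # half precision via the default "native" amp_backend
    gradient_clip_val=0.5,       # clip gradients (gradient_clip_algorithm='norm' by default)
    accumulate_grad_batches=4,   # accumulate gradients over 4 batches before each optimizer step
    limit_val_batches=0.25,      # run validation on 25% of the validation set
    log_every_n_steps=20,        # log more often than the default of 50 steps
    deterministic=True,          # enable cudnn.deterministic for reproducibility
)
```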
- fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, train_dataloader=None)[source]
Runs the full optimization routine.
- Parameters
- model (LightningModule) – Model to fit.
- train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this page.
- val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- Return type
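A self-contained sketch of fit() follows. LitRegressor and toy_loader are illustrative names invented for this example (they are not part of the Lightning API); the later method sketches on this page reuse them.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class LitRegressor(pl.LightningModule):
    """Toy LightningModule used only to illustrate the Trainer methods on this page."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 1)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.mse_loss(self(x), y)
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val_loss", torch.nn.functional.mse_loss(self(x), y))

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.01)

def toy_loader():
    """Random (input, target) pairs standing in for a real dataset."""
    return DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=16)

# Run the full optimization routine on the toy data.
trainer = pl.Trainer(max_epochs=2, logger=False, checkpoint_callback=False)
trainer.fit(LitRegressor(), train_dataloaders=toy_loader(), val_dataloaders=toy_loader())
```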
- predict(model=None, dataloaders=None, datamodule=None, return_predictions=None, ckpt_path='best')[source]
Separated from fit to make sure you never run on your prediction set until you want to. This will call the model forward function to compute predictions.
- Parameters
- model (Optional[LightningModule]) – The model to predict with.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying prediction samples.
- datamodule (Optional[LightningDataModule]) – The datamodule with a predict_dataloader method that returns one or more dataloaders.
- return_predictions (Optional[bool]) – Whether to return predictions. True by default except when an accelerator that spawns processes is used (not supported).
- ckpt_path (Optional[str]) – Either best or path to the checkpoint you wish to use to predict. If None, use the current weights of the model. When the model is given as argument, this parameter will not apply.
- Return type
- Returns
Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
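A small sketch of predict(), reusing the toy LitRegressor and imports from the fit() example above; LitPredictor and the random prediction loader are invented for illustration. An explicit predict_step() is added because the loader below yields one-element batches that the default fallback to forward() would not unpack.

```python
class LitPredictor(LitRegressor):
    """Extends the toy model with a predict_step() that unpacks TensorDataset batches."""

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        (x,) = batch              # a TensorDataset with a single tensor yields [x] batches
        return self(x)

predict_loader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size=16)
trainer = pl.Trainer(logger=False)
predictions = trainer.predict(LitPredictor(), dataloaders=predict_loader)
# One entry per batch of the single dataloader: here, two tensors of shape (16, 1).
```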
- test(model=None, dataloaders=None, ckpt_path='best', verbose=True, datamodule=None, test_dataloaders=None)[source]
Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.
- Parameters
- model (Optional[LightningModule]) – The model to test.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying test samples.
- ckpt_path (Optional[str]) – Either best or path to the checkpoint you wish to test. If None, use the current weights of the model. When the model is given as argument, this parameter will not apply.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- Return type
- Returns
List of dictionaries with metrics logged during the test phase, e.g., in model or callback hooks like test_step(), test_epoch_end(), etc. The length of the list corresponds to the number of test dataloaders used.
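A sketch of test(), again building on the toy LitRegressor and toy_loader from the fit() example; the test_step() hook and the metric name test_mse are invented for illustration.

```python
class LitRegressorWithTest(LitRegressor):
    """Adds the test_step() hook that Trainer.test() drives."""

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test_mse", torch.nn.functional.mse_loss(self(x), y))

trainer = pl.Trainer(logger=False)
results = trainer.test(LitRegressorWithTest(), dataloaders=toy_loader())
# `results` is a list with one dict of logged metrics per test dataloader, e.g. [{'test_mse': ...}]
```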
- tune(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None, train_dataloader=None)[source]
Runs routines to tune hyperparameters before training.
- Parameters
- model (LightningModule) – Model to tune.
- train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this page.
- val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- scale_batch_size_kwargs (Optional[Dict[str, Any]]) – Arguments for scale_batch_size().
- lr_find_kwargs (Optional[Dict[str, Any]]) – Arguments for lr_find().
- Return type
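A sketch of tune() with the learning rate finder enabled; LitTunable is an illustrative subclass of the toy LitRegressor (from the fit() example) that exposes the self.learning_rate attribute auto_lr_find looks for.

```python
class LitTunable(LitRegressor):
    """Exposes a learning_rate attribute so auto_lr_find can store its suggestion."""

    def __init__(self):
        super().__init__()
        self.learning_rate = 0.01

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=self.learning_rate)

model = LitTunable()
trainer = pl.Trainer(auto_lr_find=True, logger=False, checkpoint_callback=False)
# lr_find_kwargs is forwarded to lr_find(); num_training shortens the search on the toy data.
trainer.tune(model, train_dataloaders=toy_loader(), lr_find_kwargs={"num_training": 20})
print(model.learning_rate)   # the suggested learning rate written back by the finder
```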
- validate(model=None, dataloaders=None, ckpt_path='best', verbose=True, datamodule=None, val_dataloaders=None)[source]
Perform one evaluation epoch over the validation set.
- Parameters
- model (Optional[LightningModule]) – The model to validate.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying validation samples.
- ckpt_path (Optional[str]) – Either best or path to the checkpoint you wish to validate. If None, use the current weights of the model. When the model is given as argument, this parameter will not apply.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- Return type
- Returns
List of dictionaries with metrics logged during the validation phase, e.g., in model or callback hooks like validation_step(), validation_epoch_end(), etc. The length of the list corresponds to the number of validation dataloaders used.
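Finally, a sketch of validate(), reusing the toy LitRegressor (whose validation_step() logs val_loss) and toy_loader from the fit() example above.

```python
trainer = pl.Trainer(logger=False)
results = trainer.validate(LitRegressor(), dataloaders=toy_loader())
# `results` is a list with one dict of logged metrics per validation dataloader,
# e.g. [{'val_loss': ...}]
```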