Customize every aspect of training via flags

Trainer to automate the training.

class pytorch_lightning.trainer.trainer.Trainer(logger=True, checkpoint_callback=True, callbacks=None, default_root_dir=None, gradient_clip_val=0.0, gradient_clip_algorithm='norm', process_position=0, num_nodes=1, num_processes=1, devices=None, gpus=None, auto_select_gpus=False, tpu_cores=None, ipus=None, log_gpu_memory=None, progress_bar_refresh_rate=None, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=1, max_epochs=None, min_epochs=None, max_steps=None, min_steps=None, max_time=None, limit_train_batches=1.0, limit_val_batches=1.0, limit_test_batches=1.0, limit_predict_batches=1.0, val_check_interval=1.0, flush_logs_every_n_steps=100, log_every_n_steps=50, accelerator=None, sync_batchnorm=False, precision=32, weights_summary='top', weights_save_path=None, num_sanity_val_steps=2, truncated_bptt_steps=None, resume_from_checkpoint=None, profiler=None, benchmark=False, deterministic=False, reload_dataloaders_every_n_epochs=0, reload_dataloaders_every_epoch=False, auto_lr_find=False, replace_sampler_ddp=True, terminate_on_nan=False, auto_scale_batch_size=False, prepare_data_per_node=True, plugins=None, amp_backend='native', amp_level='O2', distributed_backend=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', stochastic_weight_avg=False)[source]

Bases: pytorch_lightning.trainer.callback_hook.TrainerCallbackHookMixin, pytorch_lightning.trainer.model_hooks.TrainerModelHooksMixin, pytorch_lightning.trainer.optimizers.TrainerOptimizersMixin, pytorch_lightning.trainer.logging.TrainerLoggingMixin, pytorch_lightning.trainer.training_tricks.TrainerTrainingTricksMixin, pytorch_lightning.trainer.data_loading.TrainerDataLoadingMixin, pytorch_lightning.trainer.deprecated_api.DeprecatedTrainerAttributes

  • accelerator (Union[str, Accelerator, None]) – Previously known as distributed_backend (dp, ddp, ddp2, etc…). Can also take in an accelerator object for custom hardware.

  • accumulate_grad_batches (Union[int, Dict[int, int], List[list]]) – Accumulates grads every k batches or as set up in the dict.

  • amp_backend (str) – The mixed precision backend to use (“native” or “apex”)

  • amp_level (str) – The optimization level to use (O1, O2, etc…).

  • auto_lr_find (Union[bool, str]) – If set to True, will make trainer.tune() run a learning rate finder, trying to optimize the initial learning rate for faster convergence. The trainer.tune() method will set the suggested learning rate in self.lr or self.learning_rate in the LightningModule. To use a different key, set a string with the key name instead of True.

  • auto_scale_batch_size (Union[str, bool]) – If set to True, will initially run a batch size finder trying to find the largest batch size that fits into memory. The result will be stored in self.batch_size in the LightningModule. Additionally, can be set to either power that estimates the batch size through a power search or binsearch that estimates the batch size through a binary search.

  • auto_select_gpus (bool) – If enabled and gpus is an integer, pick available gpus automatically. This is especially useful when GPUs are configured to be in “exclusive mode”, such that only one process at a time can access them.

  • benchmark (bool) – If true enables cudnn.benchmark.

  • callbacks (Union[List[Callback], Callback, None]) – Add a callback or list of callbacks.

  • checkpoint_callback (bool) – If True, enable checkpointing. It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in callbacks.

  • check_val_every_n_epoch (int) – Check val every n train epochs.

  • default_root_dir (Optional[str]) – Default path for logs and weights when no logger/ckpt_callback passed. Default: os.getcwd(). Can be remote file paths such as s3://mybucket/path or ‘hdfs://path/’

  • deterministic (bool) – If true enables cudnn.deterministic.

  • devices (Union[int, str, List[int], None]) – Will be mapped to either gpus, tpu_cores, num_processes or ipus, based on the accelerator type.

  • distributed_backend (Optional[str]) – deprecated. Please use ‘accelerator’

  • fast_dev_run (Union[int, bool]) – Runs n batch(es) of train, val and test if set to n (int), or 1 batch of each if set to True, to find any bugs (i.e., a sort of unit test).

  • flush_logs_every_n_steps (int) – How often to flush logs to disk (defaults to every 100 steps).

  • gpus (Union[int, str, List[int], None]) – number of gpus to train on (int) or which GPUs to train on (list or str) applied per node

  • gradient_clip_val (float) – 0 means don’t clip.

  • gradient_clip_algorithm (str) – ‘value’ means clip_by_value, ‘norm’ means clip_by_norm. Default: ‘norm’

  • limit_train_batches (Union[int, float]) – How much of training dataset to check (float = fraction, int = num_batches)

  • limit_val_batches (Union[int, float]) – How much of validation dataset to check (float = fraction, int = num_batches)

  • limit_test_batches (Union[int, float]) – How much of test dataset to check (float = fraction, int = num_batches)

  • limit_predict_batches (Union[int, float]) – How much of prediction dataset to check (float = fraction, int = num_batches)

  • logger (Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]) – Logger (or iterable collection of loggers) for experiment tracking. A True value uses the default TensorBoardLogger. False will disable logging. If multiple loggers are provided and the save_dir property of that logger is not set, local files (checkpoints, profiler traces, etc.) are saved in default_root_dir rather than in the log_dir of any of the individual loggers.

  • log_gpu_memory (Optional[str]) – None, ‘min_max’, ‘all’. Might slow performance

  • log_every_n_steps (int) – How often to log within steps (defaults to every 50 steps).

  • prepare_data_per_node (bool) – If True, each LOCAL_RANK=0 will call prepare_data(). Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

  • process_position (int) – orders the progress bar when running multiple models on same machine.

  • progress_bar_refresh_rate (Optional[int]) – How often to refresh progress bar (in steps). Value 0 disables progress bar. Ignored when a custom progress bar is passed to callbacks. Default: None, means a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

  • profiler (Union[BaseProfiler, str, None]) – To profile individual steps during training and assist in identifying bottlenecks.

  • overfit_batches (Union[int, float]) – Overfit a fraction of training data (float) or a set number of batches (int).

  • plugins (Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str, None]) – Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

  • precision (int) – Double precision (64), full precision (32) or half precision (16). Can be used on CPU, GPU or TPUs.

  • max_epochs (Optional[int]) – Stop training once this number of epochs is reached. Disabled by default (None). If both max_epochs and max_steps are not specified, defaults to max_epochs = 1000.

  • min_epochs (Optional[int]) – Force training for at least this many epochs. Disabled by default (None). If both min_epochs and min_steps are not specified, defaults to min_epochs = 1.

  • max_steps (Optional[int]) – Stop training after this number of steps. Disabled by default (None).

  • min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).

  • max_time (Union[str, timedelta, Dict[str, int], None]) – Stop training after this amount of time has passed. Disabled by default (None). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a datetime.timedelta, or as a dictionary with keys that will be passed to datetime.timedelta.

  • num_nodes (int) – number of GPU nodes for distributed training.

  • num_processes (int) – number of processes for distributed training with distributed_backend="ddp_cpu"

  • num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders.

  • reload_dataloaders_every_n_epochs (int) – Set to a non-negative integer to reload dataloaders every n epochs. Default: 0

  • reload_dataloaders_every_epoch (bool) –

    Set to True to reload dataloaders every epoch.

    Deprecated since version v1.4: reload_dataloaders_every_epoch has been deprecated in v1.4 and will be removed in v1.6. Please use reload_dataloaders_every_n_epochs.

  • replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this will be toggled automatically when DDP is used. By default it will add shuffle=True for the train sampler and shuffle=False for the val/test sampler. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.

  • resume_from_checkpoint (Union[str, Path, None]) – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, start from scratch. If resuming from mid-epoch checkpoint, training will start from the beginning of the next epoch.

  • sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world.

  • terminate_on_nan (bool) – If set to True, will terminate training (by raising a ValueError) at the end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

  • tpu_cores (Union[int, str, List[int], None]) – How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

  • ipus (Optional[int]) – How many IPUs to train on.

  • track_grad_norm (Union[int, float, str]) – -1 disables tracking. Otherwise tracks the given p-norm. May be set to ‘inf’ for the infinity-norm.

  • truncated_bptt_steps (Optional[int]) – Deprecated in v1.3 to be removed in 1.5. Please use the truncated_bptt_steps attribute of the LightningModule instead.

  • val_check_interval (Union[int, float]) – How often to check the validation set. Use float to check within a training epoch, use int to check every n steps (batches).

  • weights_summary (Optional[str]) – Prints a summary of the weights when training begins.

  • weights_save_path (Optional[str]) – Where to save weights if specified. Will override default_root_dir for checkpoints only. Use this if you need the checkpoints stored in a different place than the logs written in default_root_dir. Can be remote file paths such as s3://mybucket/path or ‘hdfs://path/’. Defaults to default_root_dir.

  • move_metrics_to_cpu (bool) – Whether to force internal logged metrics to be moved to cpu. This can save some gpu memory, but can make training slower. Use with caution.

  • multiple_trainloader_mode (str) – How to loop over the datasets when there are multiple train loaders. In ‘max_size_cycle’ mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when running out of their data. In ‘min_size’ mode, all the datasets reload when reaching the minimum length of datasets.

  • stochastic_weight_avg (bool) – Whether to use Stochastic Weight Averaging (SWA).
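
Several of the flags above accept more than one value type. The sketch below illustrates, in plain Python, how the documented forms of max_time and the limit_*_batches flags could be interpreted. The helper names normalize_max_time and resolve_batch_limit are hypothetical and are not part of the Trainer API, and capping an integer limit at the dataloader length is an assumption of this sketch. It mirrors the descriptions in the list above rather than the library's internal code.

```python
from datetime import timedelta
from typing import Dict, Optional, Union


def normalize_max_time(
    max_time: Union[str, timedelta, Dict[str, int], None]
) -> Optional[timedelta]:
    """Convert any documented max_time form into a timedelta (hypothetical helper)."""
    if max_time is None:
        return None  # time-based stopping is disabled
    if isinstance(max_time, timedelta):
        return max_time
    if isinstance(max_time, dict):
        # Keys are forwarded to datetime.timedelta, e.g. {"days": 1, "hours": 5}
        return timedelta(**max_time)
    # Remaining case: a "DD:HH:MM:SS" string
    days, hours, minutes, seconds = (int(part) for part in max_time.split(":"))
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


def resolve_batch_limit(limit: Union[int, float], num_batches: int) -> int:
    """Interpret a limit_*_batches value for a dataloader with num_batches batches."""
    if isinstance(limit, float):
        # float = fraction of the available batches
        return int(num_batches * limit)
    # int = absolute number of batches; capped at the dataloader length (assumption)
    return min(limit, num_batches)


print(normalize_max_time("00:12:00:00"))           # string form
print(normalize_max_time({"days": 1, "hours": 5}))  # dict form
print(resolve_batch_limit(0.25, 200))               # fraction of 200 batches
print(resolve_batch_limit(10, 200))                 # fixed number of batches
```

Note that 1.0 and 1 mean different things for the limit_*_batches flags: the float selects the whole dataset, while the int selects a single batch.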

fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, train_dataloader=None)[source]

Runs the full optimization routine.

Return type

None

predict(model=None, dataloaders=None, datamodule=None, return_predictions=None, ckpt_path='best')[source]

Separated from fit to make sure you never run on your prediction set until you want to. This will call the model forward function to compute predictions.

Return type

Union[List[Any], List[List[Any]], None]


Returns a list of dictionaries, one for each provided dataloader, containing their respective predictions.

test(model=None, dataloaders=None, ckpt_path='best', verbose=True, datamodule=None, test_dataloaders=None)[source]

Perform one evaluation epoch over the test set. It’s separated from fit to make sure you never run on your test set until you want to.

Return type

List[Dict[str, float]]


Returns a list of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks like test_step(), test_epoch_end(), etc. The length of the list corresponds to the number of test dataloaders used.
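
As a rough sketch of how a result of this shape might be consumed, the snippet below formats one summary line per dataloader. The helper format_results, the metric names test_loss and test_acc, and all numeric values are made-up placeholders. The actual keys depend on what the model logs in test_step() and related hooks.

```python
from typing import Dict, List


def format_results(results: List[Dict[str, float]]) -> List[str]:
    """Render one summary line per dataloader from a test()-style result list."""
    lines = []
    for dataloader_idx, metrics in enumerate(results):
        # Sort metric names so the output order is stable across runs
        rendered = ", ".join(
            f"{name}={value:.3f}" for name, value in sorted(metrics.items())
        )
        lines.append(f"dataloader {dataloader_idx}: {rendered}")
    return lines


# Placeholder values standing in for `results = trainer.test(...)`
# with two test dataloaders; these are not real measurements.
example_results = [
    {"test_loss": 0.312, "test_acc": 0.901},
    {"test_loss": 0.457, "test_acc": 0.848},
]

for line in format_results(example_results):
    print(line)
```

The same pattern applies to the list returned by validate(), which has one entry per validation dataloader.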

tune(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None, train_dataloader=None)[source]

Runs routines to tune hyperparameters before training.

Return type

Dict[str, Union[int, _LRFinder, None]]

validate(model=None, dataloaders=None, ckpt_path='best', verbose=True, datamodule=None, val_dataloaders=None)[source]

Perform one evaluation epoch over the validation set.

Return type

List[Dict[str, float]]


Returns a list of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks like validation_step(), validation_epoch_end(), etc. The length of the list corresponds to the number of validation dataloaders used.