Trainer
- class pytorch_lightning.trainer.trainer.Trainer(logger=True, enable_checkpointing=True, callbacks=None, default_root_dir=None, gradient_clip_val=None, gradient_clip_algorithm=None, num_nodes=1, num_processes=None, devices=None, gpus=None, auto_select_gpus=None, tpu_cores=None, ipus=None, enable_progress_bar=True, overfit_batches=0.0, track_grad_norm=-1, check_val_every_n_epoch=1, fast_dev_run=False, accumulate_grad_batches=None, max_epochs=None, min_epochs=None, max_steps=-1, min_steps=None, max_time=None, limit_train_batches=None, limit_val_batches=None, limit_test_batches=None, limit_predict_batches=None, val_check_interval=None, log_every_n_steps=50, accelerator=None, strategy=None, sync_batchnorm=False, precision=32, enable_model_summary=True, num_sanity_val_steps=2, resume_from_checkpoint=None, profiler=None, benchmark=None, deterministic=None, reload_dataloaders_every_n_epochs=0, auto_lr_find=False, replace_sampler_ddp=True, detect_anomaly=False, auto_scale_batch_size=False, plugins=None, amp_backend=None, amp_level=None, move_metrics_to_cpu=False, multiple_trainloader_mode='max_size_cycle', inference_mode=True)
- Bases: object
- Customize every aspect of training via flags.
- Parameters
- accelerator (Union[str, Accelerator, None]) – Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "hpu", "mps", "auto") as well as custom accelerator instances.
- accumulate_grad_batches (Union[int, Dict[int, int], None]) – Accumulates gradients every k batches, or according to the schedule set up in the dict. Default: None.
- amp_backend (Optional[str]) – The mixed precision backend to use ("native" or "apex"). Default: "native". Deprecated since version v1.9: Setting amp_backend inside the Trainer is deprecated and will be removed in v2.0.0. This argument was only relevant for apex, which is being removed.
- amp_level (Optional[str]) – The optimization level to use (O1, O2, etc.). By default it is set to "O2" if amp_backend is set to "apex". Deprecated since version v1.8: Setting amp_level inside the Trainer is deprecated in v1.8.0 and will be removed in v2.0.0.
- auto_lr_find (Union[bool, str]) – If set to True, trainer.tune() runs a learning rate finder that tries to optimize the initial learning rate for faster convergence. trainer.tune() sets the suggested learning rate in self.lr or self.learning_rate in the LightningModule. To use a different attribute, pass its name as a string instead of True. Default: False.
- auto_scale_batch_size (Union[str, bool]) – If set to True, initially runs a batch size finder that tries to find the largest batch size that fits into memory. The result is stored in self.batch_size in the LightningModule or LightningDataModule, depending on your setup. It can also be set to "power", which estimates the batch size through a power search, or "binsearch", which estimates it through a binary search. Default: False.
- auto_select_gpus (Optional[bool]) – If enabled and gpus or devices is an integer, pick available GPUs automatically. This is especially useful when GPUs are configured to be in "exclusive mode", such that only one process at a time can access them. Default: False. Deprecated since version v1.9: auto_select_gpus has been deprecated in v1.9.0 and will be removed in v2.0.0. Please use the function find_usable_cuda_devices() instead.
- benchmark (Optional[bool]) – The value (True or False) to set torch.backends.cudnn.benchmark to. If left as None, the value of torch.backends.cudnn.benchmark set in the current session is used (False if not manually set); if deterministic is set to True, it defaults to False instead. Override to manually set a different value. Default: None.
- callbacks (Union[List[Callback], Callback, None]) – Add a callback or list of callbacks. Default: None.
- enable_checkpointing (bool) – If True, enable checkpointing. A default ModelCheckpoint callback is configured if there is no user-defined ModelCheckpoint in callbacks. Default: True.
- check_val_every_n_epoch (Optional[int]) – Perform a validation loop after every N training epochs. If None, validation is scheduled solely by the number of training batches, which requires val_check_interval to be an integer. Default: 1.
- default_root_dir (Union[str, Path, None]) – Default path for logs and weights when no logger or checkpoint callback is passed. Default: os.getcwd(). Can be a remote file path such as "s3://mybucket/path" or "hdfs://path/".
- detect_anomaly (bool) – Enable anomaly detection for the autograd engine. Default: False.
- deterministic (Union[bool, Literal['warn'], None]) – If True, PyTorch operations must use deterministic algorithms. Set to "warn" to use deterministic algorithms whenever possible, throwing warnings on operations that don't support deterministic mode (requires PyTorch 1.11+). If not set, defaults to False. Default: None.
- devices (Union[List[int], str, int, None]) – Will be mapped to either gpus, tpu_cores, num_processes or ipus, based on the accelerator type.
- fast_dev_run (Union[int, bool]) – Runs n batch(es) of train, val and test if set to n (int), or 1 batch if set to True, to find any bugs (i.e., a sort of unit test). Default: False.
- gpus (Union[List[int], str, int, None]) – Number of GPUs to train on (int) or which GPUs to train on (list or str), applied per node. Default: None. Deprecated since version v1.7: gpus has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='gpu' and devices=x instead.
- gradient_clip_val (Union[int, float, None]) – The value at which to clip gradients. Passing gradient_clip_val=None disables gradient clipping. If using Automatic Mixed Precision (AMP), the gradients are unscaled before clipping. Default: None.
- gradient_clip_algorithm (Optional[str]) – The gradient clipping algorithm to use. Pass gradient_clip_algorithm="value" to clip by value, and gradient_clip_algorithm="norm" to clip by norm. By default it is set to "norm".
- limit_train_batches (Union[int, float, None]) – How much of the training dataset to check (float = fraction, int = number of batches). Default: 1.0.
- limit_val_batches (Union[int, float, None]) – How much of the validation dataset to check (float = fraction, int = number of batches). Default: 1.0.
- limit_test_batches (Union[int, float, None]) – How much of the test dataset to check (float = fraction, int = number of batches). Default: 1.0.
- limit_predict_batches (Union[int, float, None]) – How much of the prediction dataset to check (float = fraction, int = number of batches). Default: 1.0.
- logger (Union[Logger, Iterable[Logger], bool]) – Logger (or iterable collection of loggers) for experiment tracking. True uses the default TensorBoardLogger if it is installed, otherwise CSVLogger. False disables logging. If multiple loggers are provided, local files (checkpoints, profiler traces, etc.) are saved in the log_dir of the first logger. Default: True.
- log_every_n_steps (int) – How often to log, measured in training steps. Default: 50.
- enable_progress_bar (bool) – Whether to enable the progress bar by default. Default: True.
- profiler (Union[Profiler, str, None]) – Profiles individual steps during training to help identify bottlenecks. Default: None.
- overfit_batches (Union[int, float]) – Overfit a fraction of training/validation data (float) or a set number of batches (int). Default: 0.0.
- plugins (Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync, str, List[Union[PrecisionPlugin, ClusterEnvironment, CheckpointIO, LayerSync, str]], None]) – Plugins allow modification of core behavior like DDP and AMP, and enable custom Lightning plugins. Default: None.
- precision (Union[Literal[64, 32, 16], Literal['64', '32', '16', 'bf16']]) – Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16). Can be used on CPU, GPU, TPUs, HPUs or IPUs. Default: 32.
- max_epochs (Optional[int]) – Stop training once this number of epochs is reached. Disabled by default (None). If neither max_epochs nor max_steps is specified, defaults to max_epochs = 1000. To enable infinite training, set max_epochs = -1.
- min_epochs (Optional[int]) – Force training for at least this many epochs. Disabled by default (None).
- max_steps (int) – Stop training after this number of steps. Disabled by default (-1). If max_steps = -1 and max_epochs = None, defaults to max_epochs = 1000. To enable infinite training, set max_epochs to -1.
- min_steps (Optional[int]) – Force training for at least this number of steps. Disabled by default (None).
- max_time (Union[str, timedelta, Dict[str, int], None]) – Stop training after this amount of time has passed. Disabled by default (None). The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a datetime.timedelta, or as a dictionary with keys that will be passed to datetime.timedelta.
- num_nodes (int) – Number of GPU nodes for distributed training. Default: 1.
- num_processes (Optional[int]) – Number of processes for distributed training with accelerator="cpu". Default: 1. Deprecated since version v1.7: num_processes has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='cpu' and devices=x instead.
- num_sanity_val_steps (int) – Sanity check runs n validation batches before starting the training routine. Set it to -1 to run all batches in all validation dataloaders. Default: 2.
- reload_dataloaders_every_n_epochs (int) – Set to a non-negative integer to reload dataloaders every n epochs. Default: 0.
- replace_sampler_ddp (bool) – Explicitly enables or disables sampler replacement. If not specified, this is toggled automatically when DDP is used. By default it adds shuffle=True for the train sampler and shuffle=False for val/test samplers. If you want to customize it, you can set replace_sampler_ddp=False and add your own distributed sampler.
- resume_from_checkpoint (Union[str, Path, None]) – Path/URL of the checkpoint from which training is resumed. If there is no checkpoint file at the path, an exception is raised. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch. Deprecated since version v1.5: resume_from_checkpoint is deprecated in v1.5 and will be removed in v2.0. Please pass the path to Trainer.fit(..., ckpt_path=...) instead.
- strategy (Union[str, Strategy, None]) – Supports different training strategies with aliases as well as custom strategies. Default: None.
- sync_batchnorm (bool) – Synchronize batch norm layers between process groups/whole world. Default: False.
- tpu_cores (Union[List[int], str, int, None]) – How many TPU cores to train on (1 or 8) / Single TPU to train on (1). Default: None. Deprecated since version v1.7: tpu_cores has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='tpu' and devices=x instead.
- ipus (Optional[int]) – How many IPUs to train on. Default: None. Deprecated since version v1.7: ipus has been deprecated in v1.7 and will be removed in v2.0. Please use accelerator='ipu' and devices=x instead.
- track_grad_norm (Union[int, float, str]) – -1 means no tracking. Otherwise, tracks the p-norm of the gradients for that value of p; may be set to "inf" for the infinity norm. If using Automatic Mixed Precision (AMP), the gradients are unscaled before logging them. Default: -1.
- val_check_interval (Union[int, float, None]) – How often to check the validation set. Pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch. Pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or during iteration-based training. Default: 1.0.
- enable_model_summary (bool) – Whether to enable model summarization by default. Default: True.
- move_metrics_to_cpu (bool) – Whether to force internal logged metrics to be moved to CPU. This can save some GPU memory, but can make training slower. Use with care. Default: False.
- multiple_trainloader_mode (str) – How to loop over the datasets when there are multiple train loaders. In "max_size_cycle" mode, the trainer ends one epoch when the largest dataset is traversed, and smaller datasets reload when they run out of data. In "min_size" mode, the epoch ends when the shortest dataset is exhausted, at which point all the datasets reload. Default: "max_size_cycle".
- inference_mode (bool) – Whether to use torch.inference_mode() or torch.no_grad() during evaluation (validate/test/predict). Default: True.
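The three accepted forms of max_time can be illustrated with a small normalizing function. This is a minimal sketch assuming the string form is always exactly four colon-separated integers; parse_max_time is a hypothetical helper written for this example, not a Lightning API, and it does not reproduce Lightning's own parsing or validation.

```python
from datetime import timedelta
from typing import Dict, Union


def parse_max_time(value: Union[str, timedelta, Dict[str, int]]) -> timedelta:
    """Normalize a max_time-style value into a timedelta (hypothetical helper)."""
    if isinstance(value, timedelta):
        return value
    if isinstance(value, dict):
        # dict keys are forwarded as keyword arguments to datetime.timedelta
        return timedelta(**value)
    # "DD:HH:MM:SS" string: days, hours, minutes, seconds
    days, hours, minutes, seconds = (int(part) for part in value.split(":"))
    return timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)


twelve_hours = parse_max_time("00:12:00:00")
print(twelve_hours)  # 12:00:00
```

Following the description above, Trainer(max_time="00:12:00:00"), Trainer(max_time=timedelta(hours=12)) and Trainer(max_time={"hours": 12}) should all express the same 12-hour budget.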
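When accumulate_grad_batches is a dict, the accumulation factor changes over the course of training. The sketch below assumes the convention of Lightning's GradientAccumulationScheduler callback: each key is a 0-based epoch from which its factor applies, and epochs before the first key use a factor of 1. accumulation_factor is a hypothetical helper for illustration only.

```python
from typing import Dict


def accumulation_factor(schedule: Dict[int, int], epoch: int) -> int:
    """Return the accumulation factor in effect at a 0-based epoch (hypothetical helper)."""
    factor = 1  # assumed default before the first scheduled epoch
    for start_epoch in sorted(schedule):
        if epoch >= start_epoch:
            factor = schedule[start_epoch]  # this entry has taken effect
        else:
            break  # later entries have not started yet
    return factor


schedule = {5: 3, 10: 20}
print([accumulation_factor(schedule, e) for e in (0, 4, 5, 9, 10, 50)])
# [1, 1, 3, 3, 20, 20]
```

With a plain int k instead of a dict, the same factor k applies to every epoch: gradients from k consecutive batches are accumulated before each optimizer step.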
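The difference between gradient_clip_algorithm="value" and "norm" is easiest to see on a tiny gradient vector. This is a minimal pure-Python sketch; clip_by_value and clip_by_norm are hypothetical helpers, whereas on real tensors PyTorch provides torch.nn.utils.clip_grad_value_ and torch.nn.utils.clip_grad_norm_ (the latter also adds a small epsilon to the norm, which is omitted here).

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_val: float) -> List[float]:
    # "value": clamp every element independently into [-clip_val, clip_val]
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads: List[float], clip_val: float) -> List[float]:
    # "norm": rescale the whole vector so its L2 norm is at most clip_val
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= clip_val:
        return list(grads)  # already within the limit, leave unchanged
    scale = clip_val / norm
    return [g * scale for g in grads]


grads = [3.0, -4.0]  # L2 norm is 5.0
print(clip_by_value(grads, 1.0))  # [1.0, -1.0]: direction changes
print(clip_by_norm(grads, 1.0))   # roughly [0.6, -0.8]: direction preserved
```

Clipping by value can change the direction of the update, while clipping by norm only shrinks its length.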
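The two multiple_trainloader_mode settings can be mimicked with plain Python iterators. In this sketch each "loader" is just a list of batches and combine_batches is a hypothetical helper; it only mirrors the end-of-epoch rule described for this flag and ignores everything else a real DataLoader does (shuffling, workers, distributed sampling).

```python
from itertools import cycle
from typing import Dict, List, Sequence


def combine_batches(
    loaders: Dict[str, Sequence], mode: str = "max_size_cycle"
) -> List[Dict[str, object]]:
    """Pair up one batch from each loader per step for a single epoch."""
    if mode == "max_size_cycle":
        # epoch length follows the longest loader; shorter ones start over when exhausted
        length = max(len(batches) for batches in loaders.values())
        iterators = {name: cycle(batches) for name, batches in loaders.items()}
    elif mode == "min_size":
        # epoch length follows the shortest loader; leftover batches are not visited
        length = min(len(batches) for batches in loaders.values())
        iterators = {name: iter(batches) for name, batches in loaders.items()}
    else:
        raise ValueError(f"unsupported mode: {mode!r}")
    return [{name: next(it) for name, it in iterators.items()} for _ in range(length)]


loaders = {"a": [1, 2, 3], "b": ["x", "y"]}
print(combine_batches(loaders, "max_size_cycle"))
# [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}, {'a': 3, 'b': 'x'}]
print(combine_batches(loaders, "min_size"))
# [{'a': 1, 'b': 'x'}, {'a': 2, 'b': 'y'}]
```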
 
- fit(model, train_dataloaders=None, val_dataloaders=None, datamodule=None, ckpt_path=None)
- Runs the full optimization routine.
- Parameters
- model (LightningModule) – Model to fit.
- train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
- val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
- ckpt_path (Optional[str]) – Path/URL of the checkpoint from which training is resumed. Could also be one of two special keywords, "last" and "hpc". If there is no checkpoint file at the path, an exception is raised. If resuming from a mid-epoch checkpoint, training will start from the beginning of the next epoch.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
 
- Return type
- None
 
- predict(model=None, dataloaders=None, datamodule=None, return_predictions=None, ckpt_path=None)
- Run inference on your data. This will call the model forward function to compute predictions. Useful to perform distributed and batched predictions. Logging is disabled in the predict hooks.
- Parameters
- model (Optional[LightningModule]) – The model to predict with.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying prediction samples.
- datamodule (Optional[LightningDataModule]) – The datamodule with a predict_dataloader method that returns one or more dataloaders.
- return_predictions (Optional[bool]) – Whether to return predictions. True by default, except when an accelerator that spawns processes is used (not supported).
- ckpt_path (Optional[str]) – Either "best", "last", "hpc" or the path to the checkpoint you wish to predict with. If None and the model instance was passed, the current weights are used. Otherwise, the best model checkpoint from the previous trainer.fit call will be loaded if a checkpoint callback is configured.
 
- Return type
- Returns
- A list of dictionaries, one for each provided dataloader, containing their respective predictions.
- See the Lightning inference section for more.
- reset_predict_dataloader(model=None)
- Resets the predict dataloader and determines the number of batches.
- Parameters
- model (Optional[LightningModule]) – The LightningModule if called outside of the trainer scope.
- Return type
- None
 
- reset_test_dataloader(model=None)
- Resets the test dataloader and determines the number of batches.
- Parameters
- model (Optional[LightningModule]) – The LightningModule if called outside of the trainer scope.
- Return type
- None
 
- reset_train_dataloader(model=None)
- Resets the train dataloader and initialises required variables (number of batches, when to validate, etc.).
- Parameters
- model (Optional[LightningModule]) – The LightningModule if called outside of the trainer scope.
- Return type
- None
 
- reset_val_dataloader(model=None)
- Resets the validation dataloader and determines the number of batches.
- Parameters
- model (Optional[LightningModule]) – The LightningModule if called outside of the trainer scope.
- Return type
- None
 
- save_checkpoint(filepath, weights_only=False, storage_options=None)
- Runs routine to create a checkpoint. 
- test(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)
- Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your test set until you want to.
- Parameters
- model (Optional[LightningModule]) – The model to test.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying test samples.
- ckpt_path (Optional[str]) – Either "best", "last", "hpc" or the path to the checkpoint you wish to test. If None and the model instance was passed, the current weights are used. Otherwise, the best model checkpoint from the previous trainer.fit call will be loaded if a checkpoint callback is configured.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
 
- Return type
- Returns
- List of dictionaries with metrics logged during the test phase, e.g., in model or callback hooks like test_step(), test_epoch_end(), etc. The length of the list corresponds to the number of test dataloaders used.
 
- tune(model, train_dataloaders=None, val_dataloaders=None, dataloaders=None, datamodule=None, scale_batch_size_kwargs=None, lr_find_kwargs=None, method='fit')
- Runs routines to tune hyperparameters before training.
- Parameters
- model (LightningModule) – Model to tune.
- train_dataloaders (Union[DataLoader, Sequence[DataLoader], Sequence[Sequence[DataLoader]], Sequence[Dict[str, DataLoader]], Dict[str, DataLoader], Dict[str, Dict[str, DataLoader]], Dict[str, Sequence[DataLoader]], LightningDataModule, None]) – A collection of torch.utils.data.DataLoader or a LightningDataModule specifying training samples. In the case of multiple dataloaders, please see this section.
- val_dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying validation samples.
- dataloaders (Union[DataLoader, Sequence[DataLoader], None]) – A torch.utils.data.DataLoader or a sequence of them specifying val/test/predict samples used for running the tuner on validation/testing/prediction.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
- scale_batch_size_kwargs (Optional[Dict[str, Any]]) – Arguments for scale_batch_size().
- lr_find_kwargs (Optional[Dict[str, Any]]) – Arguments for lr_find().
- method (Literal['fit', 'validate', 'test', 'predict']) – Method to run the tuner on. It can be any of ("fit", "validate", "test", "predict").
 
- Return type
- _TunerResult
 
- validate(model=None, dataloaders=None, ckpt_path=None, verbose=True, datamodule=None)
- Perform one evaluation epoch over the validation set.
- Parameters
- model (Optional[LightningModule]) – The model to validate.
- dataloaders (Union[DataLoader, Sequence[DataLoader], LightningDataModule, None]) – A torch.utils.data.DataLoader or a sequence of them, or a LightningDataModule specifying validation samples.
- ckpt_path (Optional[str]) – Either "best", "last", "hpc" or the path to the checkpoint you wish to validate. If None and the model instance was passed, the current weights are used. Otherwise, the best model checkpoint from the previous trainer.fit call will be loaded if a checkpoint callback is configured.
- datamodule (Optional[LightningDataModule]) – An instance of LightningDataModule.
 
- Return type
- Returns
- List of dictionaries with metrics logged during the validation phase, e.g., in model or callback hooks like validation_step(), validation_epoch_end(), etc. The length of the list corresponds to the number of validation dataloaders used.
 
- property checkpoint_callback: Optional[pytorch_lightning.callbacks.checkpoint.Checkpoint]
- The first ModelCheckpoint callback in the Trainer.callbacks list, or None if it doesn't exist.
- Return type
- Optional[Checkpoint]
 
- property checkpoint_callbacks: List[pytorch_lightning.callbacks.checkpoint.Checkpoint]
- A list of all instances of ModelCheckpoint found in the Trainer.callbacks list.
- Return type
- List[Checkpoint]
 
- property ckpt_path: Optional[str]
- Set to the path/URL of a checkpoint loaded via fit(), validate(), test(), or predict(); None otherwise.
- property current_epoch: int
- The current epoch, updated after the epoch end hooks are run.
- Return type
- int
 
- property default_root_dir: str
- The default location to save artifacts of loggers, checkpoints, etc. It is used as a fallback if the logger or checkpoint callback does not define a specific save path.
- Return type
- str
 
- property early_stopping_callback: Optional[pytorch_lightning.callbacks.early_stopping.EarlyStopping]
- The first EarlyStopping callback in the Trainer.callbacks list, or None if it doesn't exist.
- Return type
- Optional[EarlyStopping]
 
- property early_stopping_callbacks: List[pytorch_lightning.callbacks.early_stopping.EarlyStopping]
- A list of all instances of EarlyStopping found in the Trainer.callbacks list.
- Return type
- List[EarlyStopping]
 
- property estimated_stepping_batches: Union[int, float]
- Estimated stepping batches for the complete training, inferred from the DataLoaders, the gradient accumulation factor and the distributed setup.
- Examples:

```python
import torch


def configure_optimizers(self):
    optimizer = ...  # construct your optimizer here
    # size the one-cycle schedule to the trainer's estimate of total optimizer steps
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=1e-3, total_steps=self.trainer.estimated_stepping_batches
    )
    return [optimizer], [scheduler]
```
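The estimate can be approximated by hand. The sketch below is a simplification under stated assumptions, not Lightning's implementation: it takes batches_per_epoch to already be the per-process count (after any distributed sampler has split the data), assumes a partial accumulation window at the end of an epoch still triggers an optimizer step (hence the ceiling), treats max_epochs=-1 as infinite training, and caps the result at max_steps when that is set. estimate_stepping_batches is a hypothetical helper.

```python
import math
from typing import Union


def estimate_stepping_batches(
    batches_per_epoch: int,
    accumulate_grad_batches: int,
    max_epochs: int,
    max_steps: int = -1,
) -> Union[int, float]:
    """Approximate the total number of optimizer steps (hypothetical helper)."""
    if max_epochs == -1:
        # infinite training: only max_steps (if set) bounds the estimate
        return float("inf") if max_steps == -1 else max_steps
    # assumed: a leftover partial accumulation window still produces one step
    steps_per_epoch = math.ceil(batches_per_epoch / accumulate_grad_batches)
    total = steps_per_epoch * max(max_epochs, 1)
    return total if max_steps == -1 else min(total, max_steps)


print(estimate_stepping_batches(100, 4, 10))                 # 250
print(estimate_stepping_batches(101, 4, 10))                 # 260
print(estimate_stepping_batches(100, 4, 10, max_steps=200))  # 200
```

A value of this kind is what a step-based scheduler such as OneCycleLR needs for its total_steps argument; note that the infinite-training case returns float("inf"), which such schedulers cannot accept.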
- property global_step: int
- The number of optimizer steps taken (does not reset each epoch). This includes multiple optimizers and TBPTT steps (if enabled).
- Return type
- int
 
- property model: Optional[torch.nn.modules.module.Module]
- The LightningModule, possibly wrapped in DataParallel or DistributedDataParallel. To access the pure LightningModule, use lightning_module() instead.
- property prediction_writer_callbacks: List[pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter]
- A list of all instances of BasePredictionWriter found in the Trainer.callbacks list.
- Return type
- List[BasePredictionWriter]
 
- property progress_bar_callback: Optional[pytorch_lightning.callbacks.progress.base.ProgressBarBase]
- An instance of ProgressBarBase found in the Trainer.callbacks list, or None if one doesn't exist.
- Return type
- Optional[ProgressBarBase]