# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.# THIS FILE MUST READ EASILY, FOR UNDERSTANDING AND DEBUGGING PURPOSES.# DO NOT OBSCURE THE TRAINING LOOP# THIS IS A HARD REQUIREMENT TO CONTRIBUTING TO LIGHTNING# WE FAVOR READABILITY OVER ENGINEERING-CONSTRUCTS BY DESIGN# DO NOT REMOVE THIS NOTICE# - WILLIAM FALCON"""Trainer to automate the training."""importloggingimportmathimportosfromcontextlibimportcontextmanagerfromdatetimeimporttimedeltafromtypingimportAny,Dict,Generator,Iterable,List,Optional,Unionfromweakrefimportproxyimporttorchfromtorch.optimimportOptimizerimportlightning.pytorchasplfromlightning.fabric.utilities.apply_funcimportconvert_tensors_to_scalarsfromlightning.fabric.utilities.cloud_ioimport_is_local_file_protocolfromlightning.fabric.utilities.typesimport_PATHfromlightning.pytorch.acceleratorsimportAcceleratorfromlightning.pytorch.callbacksimportCallback,Checkpoint,EarlyStopping,ProgressBarfromlightning.pytorch.core.datamoduleimportLightningDataModulefromlightning.pytorch.loggersimportLoggerfromlightning.pytorch.loggers.csv_logsimportCSVLoggerfromlightning.pytorch.loggers.tensorboardimportTensorBoardLoggerfromlightning.pytorch.loggers.utilitiesimport_log_hyperparamsfromlightning.pytorch.loopsimport_PredictionLoop,_TrainingEpochLoopfromlightning.pytorch.loops.evaluation_loopimport_EvaluationLoopfromlightning.pytorch.loops.fit_loopimport_FitLoopfromlightning.pytorch.loops.utilitiesimport_parse_loop_limits,_reset_progressfromlightning.pytorch.plu
ginsimport_PLUGIN_INPUT,Precisionfromlightning.pytorch.profilersimportProfilerfromlightning.pytorch.strategiesimportParallelStrategy,Strategyfromlightning.pytorch.trainerimportcall,setupfromlightning.pytorch.trainer.configuration_validatorimport_verify_loop_configurationsfromlightning.pytorch.trainer.connectors.accelerator_connectorimport(_LITERAL_WARN,_PRECISION_INPUT,_PRECISION_INPUT_STR,_AcceleratorConnector,)fromlightning.pytorch.trainer.connectors.callback_connectorimport_CallbackConnectorfromlightning.pytorch.trainer.connectors.checkpoint_connectorimport_CheckpointConnectorfromlightning.pytorch.trainer.connectors.data_connectorimport_DataConnectorfromlightning.pytorch.trainer.connectors.logger_connectorimport_LoggerConnectorfromlightning.pytorch.trainer.connectors.logger_connector.resultimport_OUT_DICT,_PBAR_DICT,_ResultCollectionfromlightning.pytorch.trainer.connectors.signal_connectorimport_SignalConnectorfromlightning.pytorch.trainer.statesimportRunningStage,TrainerFn,TrainerState,TrainerStatusfromlightning.pytorch.utilitiesimportGradClipAlgorithmType,parsingfromlightning.pytorch.utilities.argparseimport_defaults_from_env_varsfromlightning.pytorch.utilities.compileimport_maybe_unwrap_optimized,_verify_strategy_supports_compilefromlightning.pytorch.utilities.exceptionsimportMisconfigurationExceptionfromlightning.pytorch.utilities.model_helpersimportis_overriddenfromlightning.pytorch.utilities.rank_zeroimportrank_zero_info,rank_zero_warnfromlightning.pytorch.utilities.seedimportisolate_rngfromlightning.pytorch.utilities.typesimport(_EVALUATE_OUTPUT,_PREDICT_OUTPUT,EVAL_DATALOADERS,TRAIN_DATALOADERS,LRSchedulerConfig,)fromlightning.pytorch.utilities.warningsimportPossibleUserWarninglog=logging.getLogger(__name__)
@_defaults_from_env_vars
def __init__(
    self,
    *,
    accelerator: Union[str, Accelerator] = "auto",
    strategy: Union[str, Strategy] = "auto",
    devices: Union[List[int], str, int] = "auto",
    num_nodes: int = 1,
    precision: Optional[_PRECISION_INPUT] = None,
    logger: Optional[Union[Logger, Iterable[Logger], bool]] = None,
    callbacks: Optional[Union[List[Callback], Callback]] = None,
    fast_dev_run: Union[int, bool] = False,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    max_steps: int = -1,
    min_steps: Optional[int] = None,
    max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
    limit_train_batches: Optional[Union[int, float]] = None,
    limit_val_batches: Optional[Union[int, float]] = None,
    limit_test_batches: Optional[Union[int, float]] = None,
    limit_predict_batches: Optional[Union[int, float]] = None,
    overfit_batches: Union[int, float] = 0.0,
    val_check_interval: Optional[Union[int, float]] = None,
    check_val_every_n_epoch: Optional[int] = 1,
    num_sanity_val_steps: Optional[int] = None,
    log_every_n_steps: Optional[int] = None,
    enable_checkpointing: Optional[bool] = None,
    enable_progress_bar: Optional[bool] = None,
    enable_model_summary: Optional[bool] = None,
    accumulate_grad_batches: int = 1,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
    deterministic: Optional[Union[bool, _LITERAL_WARN]] = None,
    benchmark: Optional[bool] = None,
    inference_mode: bool = True,
    use_distributed_sampler: bool = True,
    profiler: Optional[Union[Profiler, str]] = None,
    detect_anomaly: bool = False,
    barebones: bool = False,
    plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
    sync_batchnorm: bool = False,
    reload_dataloaders_every_n_epochs: int = 0,
    default_root_dir: Optional[_PATH] = None,
) -> None:
    r"""Customize every aspect of training via flags.

    Args:
        accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto")
            as well as custom accelerator instances.

        strategy: Supports different training strategies with aliases as well as custom strategies.
            Default: ``"auto"``.

        devices: The devices to use. Can be set to a positive number (int or str), a sequence of device indices
            (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
            automatic selection based on the chosen accelerator. Default: ``"auto"``.

        num_nodes: Number of GPU nodes for distributed training.
            Default: ``1``.

        precision: Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
            16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
            Can be used on CPU, GPU, TPUs, or HPUs.
            Default: ``'32-true'``.

        logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
            the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
            ``False`` will disable logging. If multiple loggers are provided, local files
            (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
            Default: ``True``.

        callbacks: Add a callback or list of callbacks.
            Default: ``None``.

        fast_dev_run: Runs ``n`` batch(es) of train, val and test if set to ``n`` (int), or a single batch if set
            to ``True``, to find any bugs (i.e. a sort of unit test).
            Default: ``False``.

        max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
            If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
            To enable infinite training, set ``max_epochs = -1``.

        min_epochs: Force training for at least this many epochs. Disabled by default (None).

        max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
            and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
            ``max_epochs`` to ``-1``.

        min_steps: Force training for at least this number of steps. Disabled by default (``None``).

        max_time: Stop training after this amount of time has passed. Disabled by default (``None``).
            The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
            :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
            :class:`datetime.timedelta`.

        limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).
            Default: ``1.0``.

        limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).
            Default: ``1.0``.

        limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).
            Default: ``1.0``.

        limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).
            Default: ``1.0``.

        overfit_batches: Overfit a fraction of training/validation data (float) or a set number of batches (int).
            Default: ``0.0``.

        val_check_interval: How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to
            check after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of
            training batches. An ``int`` value can only be higher than the number of training batches when
            ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
            across epochs or during iteration-based training.
            Default: ``1.0``.

        check_val_every_n_epoch: Perform a validation loop after every `N` training epochs. If ``None``,
            validation will be done solely based on the number of training batches, requiring
            ``val_check_interval`` to be an integer value.
            Default: ``1``.

        num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
            Set it to `-1` to run all batches in all validation dataloaders.
            Default: ``2``.

        log_every_n_steps: How often to log within steps.
            Default: ``50``.

        enable_checkpointing: If ``True``, enable checkpointing.
            It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
            Default: ``True``.

        enable_progress_bar: Whether to enable the progress bar by default.
            Default: ``True``.

        enable_model_summary: Whether to enable model summarization by default.
            Default: ``True``.

        accumulate_grad_batches: Accumulates gradients over k batches before stepping the optimizer.
            Default: 1.

        gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
            gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before
            clipping.
            Default: ``None``.

        gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
            to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
            be set to ``"norm"``.

        deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
            Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
            that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``.

        benchmark: The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
            The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
            (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
            is set to ``True``, this will default to ``False``. Override to manually set a different value.
            Default: ``None``.

        inference_mode: Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
            evaluation (``validate``/``test``/``predict``).

        use_distributed_sampler: Whether to wrap the DataLoader's sampler with
            :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
            strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
            ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
            ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
            sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
            we don't do this automatically.

        profiler: To profile individual steps during training and assist in identifying bottlenecks.
            Default: ``None``.

        detect_anomaly: Enable anomaly detection for the autograd engine.
            Default: ``False``.

        barebones: Whether to run in "barebones mode", where all features that may impact raw speed are
            disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
            runs. The following features are deactivated:
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
            :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
            :meth:`~lightning.pytorch.core.LightningModule.log`,
            :meth:`~lightning.pytorch.core.LightningModule.log_dict`.

        plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
            Default: ``None``.

        sync_batchnorm: Synchronize batch norm layers between process groups/whole world.
            Default: ``False``.

        reload_dataloaders_every_n_epochs: Set to a positive integer to reload dataloaders every n epochs.
            Default: ``0``.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Default: ``os.getcwd()``.
            Can be remote file paths such as ``s3://mybucket/path`` or ``hdfs://path/``.

    Raises:
        TypeError:
            If ``gradient_clip_val`` is not an int or float.

        MisconfigurationException:
            If ``gradient_clip_algorithm`` is invalid.

        ValueError:
            If ``barebones=True`` is combined with any of the features listed under ``barebones`` above
            (checkpointing, logging, progress bar, model summary, sanity checking, ``fast_dev_run``,
            anomaly detection or a profiler).

    """
    super().__init__()
    # `locals()` here includes every constructor argument, so the debug log records the full configuration
    log.debug(f"{self.__class__.__name__}: Initializing trainer with parameters: {locals()}")

    # normalize a path-like object (e.g. `pathlib.Path`) to the plain path value `os.fspath` returns
    if default_root_dir is not None:
        default_root_dir = os.fspath(default_root_dir)

    self.barebones = barebones
    if barebones:
        # opt-outs: features that default to "on" (see the `else` branch below). Barebones mode forces them off
        # and refuses an explicit request to enable them.
        if enable_checkpointing:
            raise ValueError(
                f"`Trainer(barebones=True, enable_checkpointing={enable_checkpointing!r})` was passed."
                " Checkpointing can impact raw speed so it is disabled in barebones mode."
            )
        enable_checkpointing = False
        # `logger=None` (unset) and `logger=False` are the only values compatible with barebones mode
        if logger is not None and logger is not False:
            raise ValueError(
                f"`Trainer(barebones=True, logger={logger!r})` was passed."
                " Logging can impact raw speed so it is disabled in barebones mode."
            )
        logger = False
        if enable_progress_bar:
            raise ValueError(
                f"`Trainer(barebones=True, enable_progress_bar={enable_progress_bar!r})` was passed."
                " The progress bar can impact raw speed so it is disabled in barebones mode."
            )
        enable_progress_bar = False
        if log_every_n_steps is not None and log_every_n_steps != 0:
            raise ValueError(
                f"`Trainer(barebones=True, log_every_n_steps={log_every_n_steps!r})` was passed."
                " Logging can impact raw speed so it is disabled in barebones mode."
            )
        log_every_n_steps = 0
        if enable_model_summary:
            raise ValueError(
                f"`Trainer(barebones=True, enable_model_summary={enable_model_summary!r})` was passed."
                " Model summary can impact raw speed so it is disabled in barebones mode."
            )
        enable_model_summary = False
        if num_sanity_val_steps is not None and num_sanity_val_steps != 0:
            raise ValueError(
                f"`Trainer(barebones=True, num_sanity_val_steps={num_sanity_val_steps!r})` was passed."
                " Sanity checking can impact raw speed so it is disabled in barebones mode."
            )
        num_sanity_val_steps = 0
        # opt-ins: features that are off unless requested. Barebones mode only rejects a request to enable them;
        # their values are left unchanged.
        if fast_dev_run is not False and fast_dev_run != 0:
            raise ValueError(
                f"`Trainer(barebones=True, fast_dev_run={fast_dev_run!r})` was passed."
                " Development run is not meant for raw speed evaluation so it is disabled in barebones mode."
            )
        if detect_anomaly:
            raise ValueError(
                f"`Trainer(barebones=True, detect_anomaly={detect_anomaly!r})` was passed."
                " Anomaly detection can impact raw speed so it is disabled in barebones mode."
            )
        if profiler is not None:
            raise ValueError(
                f"`Trainer(barebones=True, profiler={profiler!r})` was passed."
                " Profiling can impact raw speed so it is disabled in barebones mode."
            )
        deactivated = (
            " - Checkpointing: `Trainer(enable_checkpointing=True)`",
            " - Progress bar: `Trainer(enable_progress_bar=True)`",
            " - Model summary: `Trainer(enable_model_summary=True)`",
            " - Logging: `Trainer(logger=True)`, `Trainer(log_every_n_steps>0)`,"
            " `LightningModule.log(...)`, `LightningModule.log_dict(...)`",
            " - Sanity checking: `Trainer(num_sanity_val_steps>0)`",
            " - Development run: `Trainer(fast_dev_run=True)`",
            " - Anomaly detection: `Trainer(detect_anomaly=True)`",
            " - Profiling: `Trainer(profiler=...)`",
        )
        rank_zero_info(
            "You are running in `Trainer(barebones=True)` mode. All features that may impact raw speed have been"
            " disabled to facilitate analyzing the Trainer overhead. Specifically, the following features are"
            f" deactivated:{os.linesep}{os.linesep.join(deactivated)}"
        )
    else:
        # set the opt-out defaults: `None` means "not specified by the user"
        if enable_checkpointing is None:
            enable_checkpointing = True
        if logger is None:
            logger = True
        if enable_progress_bar is None:
            enable_progress_bar = True
        if log_every_n_steps is None:
            log_every_n_steps = 50
        if enable_model_summary is None:
            enable_model_summary = True
        if num_sanity_val_steps is None:
            num_sanity_val_steps = 2

    # init connectors
    self._data_connector = _DataConnector(self)

    self._accelerator_connector = _AcceleratorConnector(
        devices=devices,
        accelerator=accelerator,
        strategy=strategy,
        num_nodes=num_nodes,
        sync_batchnorm=sync_batchnorm,
        benchmark=benchmark,
        use_distributed_sampler=use_distributed_sampler,
        deterministic=deterministic,
        precision=precision,
        plugins=plugins,
    )
    self._logger_connector = _LoggerConnector(self)
    self._callback_connector = _CallbackConnector(self)
    self._checkpoint_connector = _CheckpointConnector(self)
    self._signal_connector = _SignalConnector(self)

    # init loops
    self.fit_loop = _FitLoop(self, min_epochs=min_epochs, max_epochs=max_epochs)
    # the step limits live on the inner epoch loop, the epoch limits on the outer fit loop
    self.fit_loop.epoch_loop = _TrainingEpochLoop(self, min_steps=min_steps, max_steps=max_steps)
    self.validate_loop = _EvaluationLoop(
        self, TrainerFn.VALIDATING, RunningStage.VALIDATING, inference_mode=inference_mode
    )
    self.test_loop = _EvaluationLoop(self, TrainerFn.TESTING, RunningStage.TESTING, inference_mode=inference_mode)
    self.predict_loop = _PredictionLoop(self, inference_mode=inference_mode)

    self.accumulate_grad_batches = accumulate_grad_batches

    # init callbacks
    # NOTE(review): the callback connector presumably sets the callback-related attributes on the trainer here
    self._callback_connector.on_trainer_init(
        callbacks,
        enable_checkpointing,
        enable_progress_bar,
        default_root_dir,
        enable_model_summary,
        max_time,
    )

    # init data flags
    # annotation-only declaration; presumably assigned inside `_data_connector.on_trainer_init` — TODO confirm
    self.check_val_every_n_epoch: Optional[int]
    self._data_connector.on_trainer_init(
        val_check_interval,
        reload_dataloaders_every_n_epochs,
        check_val_every_n_epoch,
    )

    # gradient clipping
    # NOTE(review): `bool` is a subclass of `int`, so `True`/`False` also pass this isinstance check
    if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    # the algorithm name is matched case-insensitively (it is lower-cased before validation and conversion)
    if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
        gradient_clip_algorithm.lower()
    ):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
            f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    self.gradient_clip_val: Optional[Union[int, float]] = gradient_clip_val
    self.gradient_clip_algorithm: Optional[GradClipAlgorithmType] = (
        GradClipAlgorithmType(gradient_clip_algorithm.lower()) if gradient_clip_algorithm is not None else None
    )

    if detect_anomaly:
        rank_zero_info(
            "You have turned on `Trainer(detect_anomaly=True)`. This will significantly slow down compute speed and"
            " is recommended only for model debugging."
        )
    self._detect_anomaly: bool = detect_anomaly

    setup._log_device_info(self)

    self.should_stop = False
    self.state = TrainerState()

    # configure profiler
    setup._init_profiler(self, profiler)

    # init logger flags
    # annotation-only declaration; presumably assigned by `_logger_connector.on_trainer_init` — TODO confirm
    self._loggers: List[Logger]
    self._logger_connector.on_trainer_init(logger, log_every_n_steps)

    # init debugging flags
    # annotation-only declarations; presumably assigned by `setup._init_debugging_flags` below — TODO confirm
    self.val_check_batch: Union[int, float]
    self.val_check_interval: Union[int, float]
    self.num_sanity_val_steps: Union[int, float]
    self.limit_train_batches: Union[int, float]
    self.limit_val_batches: Union[int, float]
    self.limit_test_batches: Union[int, float]
    self.limit_predict_batches: Union[int, float]
    setup._init_debugging_flags(
        self,
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        fast_dev_run,
        overfit_batches,
        val_check_interval,
        num_sanity_val_steps,
    )
def fit(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[_PATH] = None,
) -> None:
    r"""Run the full optimization routine for ``model``.

    Args:
        model: The model to fit.

        train_dataloaders: An iterable or collection of iterables that yields training samples. A
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook may be passed here instead.

        val_dataloaders: An iterable or collection of iterables that yields validation samples.

        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.train_dataloader` hook.

        ckpt_path: Path or URL of the checkpoint to resume training from. The two special keywords ``"last"``
            and ``"hpc"`` are accepted as well. An exception is raised if no checkpoint file exists at the path.

    Raises:
        TypeError:
            If ``model`` is not a :class:`~lightning.pytorch.core.LightningModule` (torch < 2.0.0), or is
            neither a :class:`~lightning.pytorch.core.LightningModule` nor a
            :class:`torch._dynamo.OptimizedModule` (torch >= 2.0.0).

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    """
    # compiled models are accepted (see Raises); this helper resolves the module that will actually be trained
    unwrapped = _maybe_unwrap_optimized(model)
    self.strategy._lightning_module = unwrapped
    _verify_strategy_supports_compile(unwrapped, self.strategy)

    trainer_state = self.state
    trainer_state.fn = TrainerFn.FITTING
    trainer_state.status = TrainerStatus.RUNNING
    self.training = True

    # the real work happens in `_fit_impl`; judging by its name, the wrapper adds interrupt handling around it
    fit_args = (unwrapped, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    call._call_and_handle_interrupt(self, self._fit_impl, *fit_args)
def _fit_impl(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[_PATH] = None,
) -> None:
    """Implement :meth:`fit`; invoked by it through ``call._call_and_handle_interrupt``.

    Args:
        model: The model to fit (``fit`` passes the result of ``_maybe_unwrap_optimized``).
        train_dataloaders: Training iterables, or a ``LightningDataModule`` passed positionally.
        val_dataloaders: Validation iterables.
        datamodule: A ``LightningDataModule`` supplying the data.
        ckpt_path: Checkpoint path/keyword, resolved by the checkpoint connector.

    Raises:
        MisconfigurationException: If ``train_dataloaders`` or ``val_dataloaders`` are passed together with a
            ``datamodule``.

    """
    # lazy %-formatting: the message is only built when DEBUG logging is enabled
    log.debug("%s: trainer fit stage", self.__class__.__name__)

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(train_dataloaders, LightningDataModule):
        datamodule = train_dataloaders
        train_dataloaders = None

    # dataloaders and a datamodule are mutually exclusive sources of data
    if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
        # name the actual `fit` parameter (`train_dataloaders`) so the message matches the public signature
        raise MisconfigurationException(
            "You cannot pass `train_dataloaders` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
        )

    # links data to the trainer
    self._data_connector.attach_data(
        model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    )

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn,
        ckpt_path,
        model_provided=True,  # `fit` always receives a model
        model_connected=self.lightning_module is not None,
    )
    self._run(model, ckpt_path=ckpt_path)

    assert self.state.stopped
    self.training = False
def validate(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    r"""Run a single evaluation epoch over the validation set.

    Args:
        model: The model to validate.

        dataloaders: An iterable or collection of iterables that yields validation samples. A
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook may be passed here instead.

        ckpt_path: ``"best"``, ``"last"``, ``"hpc"`` or the path of the checkpoint to validate. If ``None`` and a
            model instance was passed, the current weights are used. Otherwise, the best checkpoint from the
            previous ``trainer.fit`` call is loaded, provided a checkpoint callback is configured.

        verbose: If True, prints the validation results.

        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.val_dataloader` hook.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        List of dictionaries with the metrics logged during the validation phase, e.g., in model- or callback
        hooks like :meth:`~lightning.pytorch.LightningModule.validation_step` etc. The list has one entry per
        validation dataloader used.

    Raises:
        TypeError:
            If no ``model`` is passed and no ``LightningModule`` was passed in a previous run.
            If ``model`` is not a `LightningModule` or a `torch._dynamo.OptimizedModule`.

        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.

        RuntimeError:
            If a compiled ``model`` is passed and the strategy does not support it.

    """
    if model is not None:
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # no model now, and none left attached by an earlier call
        raise TypeError(
            "`Trainer.validate()` requires a `LightningModule` when it hasn't been passed in a previous run"
        )
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    trainer_state = self.state
    trainer_state.fn = TrainerFn.VALIDATING
    trainer_state.status = TrainerStatus.RUNNING
    self.validating = True

    validate_args = (model, dataloaders, ckpt_path, verbose, datamodule)
    return call._call_and_handle_interrupt(self, self._validate_impl, *validate_args)
def _validate_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
    """Implement :meth:`validate`; invoked by it through ``call._call_and_handle_interrupt``.

    Args:
        model: The model to validate, or ``None`` to reuse ``self.lightning_module``.
        dataloaders: Validation iterables, or a ``LightningDataModule`` passed positionally.
        ckpt_path: Checkpoint path/keyword, resolved by the checkpoint connector.
        verbose: Stored on ``self.validate_loop.verbose``.
        datamodule: A ``LightningDataModule`` supplying the data.

    Returns:
        The results of ``self._run`` after ``convert_tensors_to_scalars`` has been applied.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed.

    """
    # lazy %-formatting: the message is only built when DEBUG logging is enabled
    log.debug("%s: trainer validate stage", self.__class__.__name__)

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(dataloaders, LightningDataModule):
        datamodule = dataloaders
        dataloaders = None
    # explicit `is not None` (as in `_fit_impl`), so a datamodule that happens to be falsy is still detected
    if dataloaders is not None and datamodule is not None:
        raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")

    # fall back to the module attached by a previous call, remembering whether the user supplied one
    model_provided = model is not None
    if not model_provided:
        model = self.lightning_module

    self.validate_loop.verbose = verbose

    # links data to the trainer
    self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    )
    results = self._run(model, ckpt_path=ckpt_path)
    # remove the tensors from the validation results
    results = convert_tensors_to_scalars(results)

    assert self.state.stopped
    self.validating = False

    return results
def test(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    r"""Run a single evaluation epoch over the test set.

    Kept separate from fit so that the test set is never touched until you explicitly ask for it.

    Args:
        model: The model to test.

        dataloaders: An iterable or collection of iterables that yields test samples. A
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook may be passed here instead.

        ckpt_path: ``"best"``, ``"last"``, ``"hpc"`` or the path of the checkpoint to test. If ``None`` and a
            model instance was passed, the current weights are used. Otherwise, the best checkpoint from the
            previous ``trainer.fit`` call is loaded, provided a checkpoint callback is configured.

        verbose: If True, prints the test results.

        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.test_dataloader` hook.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        List of dictionaries with the metrics logged during the test phase, e.g., in model- or callback hooks
        like :meth:`~lightning.pytorch.LightningModule.test_step` etc. The list has one entry per test
        dataloader used.

    Raises:
        TypeError:
            If no ``model`` is passed and no ``LightningModule`` was passed in a previous run.
            If ``model`` is not a `LightningModule` or a `torch._dynamo.OptimizedModule`.

        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.

        RuntimeError:
            If a compiled ``model`` is passed and the strategy does not support it.

    """
    if model is not None:
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # no model now, and none left attached by an earlier call
        raise TypeError("`Trainer.test()` requires a `LightningModule` when it hasn't been passed in a previous run")
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    trainer_state = self.state
    trainer_state.fn = TrainerFn.TESTING
    trainer_state.status = TrainerStatus.RUNNING
    self.testing = True

    test_args = (model, dataloaders, ckpt_path, verbose, datamodule)
    return call._call_and_handle_interrupt(self, self._test_impl, *test_args)
def _test_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[_PATH] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
    """Implement :meth:`test`; invoked by it through ``call._call_and_handle_interrupt``.

    Args:
        model: The model to test, or ``None`` to reuse ``self.lightning_module``.
        dataloaders: Test iterables, or a ``LightningDataModule`` passed positionally.
        ckpt_path: Checkpoint path/keyword, resolved by the checkpoint connector.
        verbose: Stored on ``self.test_loop.verbose``.
        datamodule: A ``LightningDataModule`` supplying the data.

    Returns:
        The results of ``self._run`` after ``convert_tensors_to_scalars`` has been applied.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are passed.

    """
    # lazy %-formatting: the message is only built when DEBUG logging is enabled
    log.debug("%s: trainer test stage", self.__class__.__name__)

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(dataloaders, LightningDataModule):
        datamodule = dataloaders
        dataloaders = None
    # explicit `is not None` (as in `_fit_impl`), so a datamodule that happens to be falsy is still detected
    if dataloaders is not None and datamodule is not None:
        raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")

    # fall back to the module attached by a previous call, remembering whether the user supplied one
    model_provided = model is not None
    if not model_provided:
        model = self.lightning_module

    self.test_loop.verbose = verbose

    # links data to the trainer
    self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    )
    results = self._run(model, ckpt_path=ckpt_path)
    # remove the tensors from the test results
    results = convert_tensors_to_scalars(results)

    assert self.state.stopped
    self.testing = False

    return results
def predict(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]:
    r"""Run inference on your data.

    This calls the model's forward function to compute predictions, which is useful for distributed and
    batched prediction. Logging is disabled in the predict hooks.

    Args:
        model: The model to predict with.

        dataloaders: An iterable or collection of iterables that yields prediction samples. A
            :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook may be passed here instead.

        datamodule: A :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that implements the
            :class:`~lightning.pytorch.core.hooks.DataHooks.predict_dataloader` hook.

        return_predictions: Whether to return predictions.
            ``True`` by default, except when an accelerator that spawns processes is used (not supported).

        ckpt_path: ``"best"``, ``"last"``, ``"hpc"`` or the path of the checkpoint to predict with. If ``None``
            and a model instance was passed, the current weights are used. Otherwise, the best checkpoint from
            the previous ``trainer.fit`` call is loaded, provided a checkpoint callback is configured.

    For more information about multiple dataloaders, see this :ref:`section <multiple-dataloaders>`.

    Returns:
        Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.

    Raises:
        TypeError:
            If no ``model`` is passed and no ``LightningModule`` was passed in a previous run.
            If ``model`` is not a `LightningModule` or a `torch._dynamo.OptimizedModule`.

        MisconfigurationException:
            If both ``dataloaders`` and ``datamodule`` are passed. Pass only one of these.

        RuntimeError:
            If a compiled ``model`` is passed and the strategy does not support it.

    See :ref:`Lightning inference section<deploy/production_basic:Predict step with your LightningModule>` for more.

    """
    if model is not None:
        model = _maybe_unwrap_optimized(model)
        self.strategy._lightning_module = model
    elif self.lightning_module is None:
        # no model now, and none left attached by an earlier call
        raise TypeError(
            "`Trainer.predict()` requires a `LightningModule` when it hasn't been passed in a previous run"
        )
    _verify_strategy_supports_compile(self.lightning_module, self.strategy)

    trainer_state = self.state
    trainer_state.fn = TrainerFn.PREDICTING
    trainer_state.status = TrainerStatus.RUNNING
    self.predicting = True

    predict_args = (model, dataloaders, datamodule, return_predictions, ckpt_path)
    return call._call_and_handle_interrupt(self, self._predict_impl, *predict_args)
def _predict_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[_PATH] = None,
) -> Optional[_PREDICT_OUTPUT]:
    """Body of the predict stage: normalize inputs, attach the predict data, resolve the checkpoint and run.

    Invoked from :meth:`predict` through ``call._call_and_handle_interrupt``.

    Returns:
        The output of :meth:`_run`.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are given.
    """
    # --------------------
    # SETUP HOOK
    # --------------------
    log.debug(f"{self.__class__.__name__}: trainer predict stage")

    self.predict_loop.return_predictions = return_predictions  # type: ignore[assignment]

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(dataloaders, LightningDataModule):
        datamodule = dataloaders
        dataloaders = None
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")

    if model is None:
        model = self.lightning_module
        model_provided = False
    else:
        model_provided = True

    # links data to the trainer
    self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

    assert self.state.fn is not None
    ckpt_path = self._checkpoint_connector._select_ckpt_path(
        self.state.fn, ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    )
    results = self._run(model, ckpt_path=ckpt_path)

    assert self.state.stopped
    self.predicting = False

    return results

def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[_PATH] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    """Shared run routine used by the stage implementations (e.g. ``_test_impl``, ``_predict_impl``).

    Connects the model, sets up the environment/profiler/data/model, restores checkpoint state, runs the
    stage via :meth:`_run_stage`, then tears down. ``on_fit_start``/``on_fit_end`` hooks are only called
    when ``self.state.fn`` is ``TrainerFn.FITTING``.

    Returns:
        Whatever :meth:`_run_stage` returns.
    """
    if self.state.fn == TrainerFn.FITTING:
        min_epochs, max_epochs = _parse_loop_limits(
            self.min_steps, self.max_steps, self.min_epochs, self.max_epochs, self
        )
        self.fit_loop.min_epochs = min_epochs
        self.fit_loop.max_epochs = max_epochs

    if self.barebones:
        # no progress bar in barebones can make it look like the Trainer hung
        rank_zero_info(
            "`Trainer(barebones=True)` started running. The progress bar is disabled so you might want to"
            " manually print the progress in your model."
        )

    # clean hparams
    if hasattr(model, "hparams"):
        parsing.clean_namespace(model.hparams)

    # attach model to the strategy
    self.strategy.connect(model)

    self._callback_connector._attach_model_callbacks()
    self._callback_connector._attach_model_logging_functions()

    _verify_loop_configurations(self)

    # ----------------------------
    # SET UP THE TRAINER
    # ----------------------------
    log.debug(f"{self.__class__.__name__}: setting up strategy environment")
    self.strategy.setup_environment()
    self.__setup_profiler()

    log.debug(f"{self.__class__.__name__}: preparing data")
    self._data_connector.prepare_data()

    call._call_setup_hook(self)  # allow user to set up LightningModule in accelerator environment
    log.debug(f"{self.__class__.__name__}: configuring model")
    call._call_configure_model(self)

    # check if we should delay restoring checkpoint till later
    if not self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    # reset logger connector
    self._logger_connector.reset_results()
    self._logger_connector.reset_metrics()

    # strategy will configure model and move it to the device
    self.strategy.setup(self)

    # hook
    if self.state.fn == TrainerFn.FITTING:
        call._call_callback_hooks(self, "on_fit_start")
        call._call_lightning_module_hook(self, "on_fit_start")

    _log_hyperparams(self)

    # strategies that need the setup to happen first restore the module/callbacks only now
    if self.strategy.restore_checkpoint_after_setup:
        log.debug(f"{self.__class__.__name__}: restoring module and callbacks from checkpoint path: {ckpt_path}")
        self._checkpoint_connector._restore_modules_and_callbacks(ckpt_path)

    # restore optimizers, etc.
    log.debug(f"{self.__class__.__name__}: restoring training state")
    self._checkpoint_connector.restore_training_state()

    self._checkpoint_connector.resume_end()

    self._signal_connector.register_signal_handlers()

    # ----------------------------
    # RUN THE TRAINER
    # ----------------------------
    results = self._run_stage()

    # ----------------------------
    # POST-Training CLEAN UP
    # ----------------------------
    log.debug(f"{self.__class__.__name__}: trainer tearing down")
    self._teardown()

    # hook
    if self.state.fn == TrainerFn.FITTING:
        call._call_callback_hooks(self, "on_fit_end")
        call._call_lightning_module_hook(self, "on_fit_end")

    log.debug(f"{self.__class__.__name__}: calling teardown hooks")
    call._call_teardown_hook(self)

    self.state.status = TrainerStatus.FINISHED
    self.state.stage = None

    return results

def _teardown(self) -> None:
    """This is the Trainer's internal teardown, unrelated to the `teardown` hooks in LightningModule and Callback;
    those are handled by :meth:`_call_teardown_hook`."""
    self.strategy.teardown()
    loop = self._active_loop
    # loop should never be `None` here but it can because we don't know the trainer stage with `ddp_spawn`
    if loop is not None:
        loop.teardown()
    self._logger_connector.teardown()
    self._signal_connector.teardown()

def _run_stage(self) -> Optional[Union[_PREDICT_OUTPUT, _EVALUATE_OUTPUT]]:
    """Run the loop that matches the current stage.

    Evaluation and prediction return their loop's output. Training runs the sanity check inside
    ``isolate_rng()`` (presumably so it does not perturb the RNG state seen by training), then the fit loop
    under ``torch.autograd.set_detect_anomaly(self._detect_anomaly)``, and returns ``None``.

    Raises:
        RuntimeError: If the trainer is in none of these stages.
    """
    # wait for all to join if on distributed
    self.strategy.barrier("run-stage")

    self.lightning_module.zero_grad()

    if self.evaluating:
        return self._evaluation_loop.run()
    if self.predicting:
        return self.predict_loop.run()
    if self.training:
        with isolate_rng():
            self._run_sanity_check()
        with torch.autograd.set_detect_anomaly(self._detect_anomaly):
            self.fit_loop.run()
        return None
    raise RuntimeError(f"Unexpected state {self.state}")

def _run_sanity_check(self) -> None:
    """Run a short validation pass before training, when validation is enabled and not restarting.

    The trainer stage is switched to sanity checking for the duration and restored afterwards; logger
    results/metrics and the validation loop progress are reset so the pass leaves no trace.
    """
    val_loop = self.fit_loop.epoch_loop.val_loop

    should_sanity_check = (
        self.enable_validation
        and self.num_sanity_val_steps > 0
        # do not sanity check if restarting because it would mess up the loaded state
        and not val_loop.restarting
    )

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    if should_sanity_check:
        stage = self.state.stage
        self.sanity_checking = True

        # reset logger connector
        self._logger_connector.reset_results()
        self._logger_connector.reset_metrics()

        call._call_callback_hooks(self, "on_sanity_check_start")

        # run eval step
        val_loop.run()

        call._call_callback_hooks(self, "on_sanity_check_end")

        # reset logger connector
        self._logger_connector.reset_results()
        self._logger_connector.reset_metrics()

        # reset the progress tracking state after sanity checking. we don't need to set the state before
        # because sanity check only runs when we are not restarting
        _reset_progress(val_loop)

        # restore the previous stage when the sanity check is finished
        self.state.stage = stage

def __setup_profiler(self) -> None:
    """Attach a weak proxy of the LightningModule to the profiler and call the profiler's ``setup``.

    ``local_rank`` is only forwarded when ``world_size > 1``; otherwise ``None`` is passed.
    """
    assert self.state.fn is not None
    local_rank = self.local_rank if self.world_size > 1 else None
    self.profiler._lightning_module = proxy(self.lightning_module)
    self.profiler.setup(stage=self.state.fn, local_rank=local_rank, log_dir=self.log_dir)
@contextmanager
def init_module(self, empty_init: Optional[bool] = None) -> Generator:
    """Tensors that you instantiate under this context manager will be created on the device right away and have
    the right data type depending on the precision setting in the Trainer.

    The parameters and tensors get created on the device and with the right data type right away without wasting
    memory being allocated unnecessarily.

    Args:
        empty_init: Whether to initialize the model with empty weights (uninitialized memory).
            If ``None``, the strategy will decide. Some strategies may not support all options.
            Set this to ``True`` if you are loading a checkpoint into a large model.

    Warns:
        PossibleUserWarning: If the strategy overrides ``model_sharded_context``; the body still enters the
            strategy's ``tensor_init_context``.
    """
    if is_overridden("model_sharded_context", self.strategy, parent=Strategy):
        # warning instead of error so that code changes are not required when changing strategies
        # this is a limitation because processes are not expected to have been launched when this is called
        rank_zero_warn(
            f"`trainer.init_module` cannot fully support proper instantiation of your model with the"
            f" `{type(self.strategy).__name__}` strategy. Please instantiate your model inside the"
            # leading space: the previous fragment ends in "the", so without it the words would run together
            " `LightningModule.configure_model` hook instead",
            # ideally we would check if `configure_model` is already overridden, but we don't have a reliable
            # reference to the model yet
            category=PossibleUserWarning,
        )
    with self.strategy.tensor_init_context(empty_init=empty_init):
        yield
def print(self, *args: Any, **kwargs: Any) -> None:
    """Print something only on the first process. If running on multiple machines, it will print from the first
    process in each machine.

    Printing happens only when ``self.local_rank == 0``. Arguments passed to this method are forwarded to the
    Python built-in :func:`print` function.
    """
    # Resolve the built-in explicitly: this callable is itself named ``print``, so an unqualified call could
    # otherwise bind to it (e.g. if defined outside a class body) and recurse.
    import builtins

    if self.local_rank == 0:
        builtins.print(*args, **kwargs)
# ----------------------------
# Accelerator properties
# ----------------------------

@property
def accelerator(self) -> Accelerator:
    """The accelerator of the current strategy (asserted to be set)."""
    assert self.strategy.accelerator
    return self.strategy.accelerator

@property
def strategy(self) -> Strategy:
    """The strategy held by the accelerator connector."""
    return self._accelerator_connector.strategy

@property
def precision_plugin(self) -> Precision:
    """The precision plugin of the current strategy."""
    return self.strategy.precision_plugin

@property
def global_rank(self) -> int:
    """The global rank reported by the strategy."""
    return self.strategy.global_rank

@property
def local_rank(self) -> int:
    """The local rank from the strategy, or ``0`` if it does not define one."""
    # some strategies define a local rank
    return getattr(self.strategy, "local_rank", 0)

@property
def node_rank(self) -> int:
    """The node rank from the strategy, or ``0`` if it does not define one."""
    # some strategies define a node rank
    return getattr(self.strategy, "node_rank", 0)

@property
def world_size(self) -> int:
    """The world size from the strategy, or ``1`` if it does not define one."""
    # some strategies define a world size
    return getattr(self.strategy, "world_size", 1)

@property
def num_nodes(self) -> int:
    """The number of nodes from the strategy, or ``1`` if it does not define one."""
    return getattr(self.strategy, "num_nodes", 1)

@property
def device_ids(self) -> List[int]:
    """List of device indexes per node.

    Parallel strategies contribute their ``parallel_devices``; others contribute their ``root_device``. A
    ``torch.device`` without an index falls back to its position in that list.
    """
    devices = (
        self.strategy.parallel_devices
        if isinstance(self.strategy, ParallelStrategy)
        else [self.strategy.root_device]
    )
    assert devices is not None
    device_ids = []
    for idx, device in enumerate(devices):
        if isinstance(device, torch.device):
            device_ids.append(device.index or idx)
        elif isinstance(device, int):
            device_ids.append(device)
    return device_ids

@property
def num_devices(self) -> int:
    """Number of devices the trainer uses per node."""
    return len(self.device_ids)

@property
def lightning_module(self) -> "pl.LightningModule":
    """The LightningModule attached to the strategy (may be ``None`` despite the annotation)."""
    # TODO: this is actually an optional return
    return self.strategy.lightning_module  # type: ignore[return-value]

@property
def optimizers(self) -> List[Optimizer]:
    """The optimizers held by the strategy."""
    return self.strategy.optimizers

@optimizers.setter
def optimizers(self, new_optims: List[Optimizer]) -> None:
    self.strategy.optimizers = new_optims

@property
def lr_scheduler_configs(self) -> List[LRSchedulerConfig]:
    """The learning-rate scheduler configurations held by the strategy."""
    return self.strategy.lr_scheduler_configs

@property
def precision(self) -> _PRECISION_INPUT_STR:
    """The precision setting reported by the strategy's precision plugin."""
    return self.strategy.precision_plugin.precision

@property
def scaler(self) -> Optional[Any]:
    """The precision plugin's ``scaler`` attribute, or ``None`` if it has none."""
    return getattr(self.precision_plugin, "scaler", None)

@property
def model(self) -> Optional[torch.nn.Module]:
    """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.

    To access the pure LightningModule, use
    :meth:`~lightning.pytorch.trainer.trainer.Trainer.lightning_module` instead.

    """
    return self.strategy.model

# ----------------------------
# General properties
# ----------------------------

@property
def log_dir(self) -> Optional[str]:
    """The directory for the current experiment. Use this to save images to, etc...

    .. note:: You must call this on all processes. Failing to do so will cause your program to stall forever.

     .. code-block:: python

         def training_step(self, batch, batch_idx):
             img = ...
             save_img(img, self.trainer.log_dir)

    """
    if len(self.loggers) > 0:
        # TensorBoard/CSV loggers expose a versioned `log_dir`; other loggers only guarantee `save_dir`
        if not isinstance(self.loggers[0], (TensorBoardLogger, CSVLogger)):
            dirpath = self.loggers[0].save_dir
        else:
            dirpath = self.loggers[0].log_dir
    else:
        dirpath = self.default_root_dir

    # the broadcast is the collective call that every process must participate in
    dirpath = self.strategy.broadcast(dirpath)
    return dirpath

@property
def is_global_zero(self) -> bool:
    """Whether this process is the global zero in multi-node training.

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            if self.trainer.is_global_zero:
                print("in node 0, accelerator 0")

    """
    return self.strategy.is_global_zero

@property
def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
    """The strategy's sampler kwargs for parallel strategies, ``None`` otherwise."""
    if isinstance(self.strategy, ParallelStrategy):
        return self.strategy.distributed_sampler_kwargs
    return None

@property
def enable_validation(self) -> bool:
    """Check if we should run validation during training."""
    return (
        self.fit_loop.epoch_loop.val_loop._data_source.is_defined()
        and is_overridden("validation_step", self.lightning_module)
        and self.limit_val_batches > 0
    )

@property
def default_root_dir(self) -> str:
    """The default location to save artifacts of loggers, checkpoints etc.

    It is used as a fallback if logger or checkpoint callback do not define specific save paths.

    Local paths are returned with ``~`` expanded and normalized; remote paths are returned unchanged.
    """
    if _is_local_file_protocol(self._default_root_dir):
        return os.path.normpath(os.path.expanduser(self._default_root_dir))
    return self._default_root_dir

@property
def early_stopping_callback(self) -> Optional[EarlyStopping]:
    """The first :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` callback in the
    Trainer.callbacks list, or ``None`` if it doesn't exist."""
    callbacks = self.early_stopping_callbacks
    return callbacks[0] if callbacks else None

@property
def early_stopping_callbacks(self) -> List[EarlyStopping]:
    """A list of all instances of :class:`~lightning.pytorch.callbacks.early_stopping.EarlyStopping` found in
    the Trainer.callbacks list."""
    return [c for c in self.callbacks if isinstance(c, EarlyStopping)]

@property
def checkpoint_callback(self) -> Optional[Checkpoint]:
    """The first :class:`~lightning.pytorch.callbacks.Checkpoint` callback (e.g.
    :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`) in the Trainer.callbacks list, or
    ``None`` if it doesn't exist."""
    callbacks = self.checkpoint_callbacks
    return callbacks[0] if callbacks else None

@property
def checkpoint_callbacks(self) -> List[Checkpoint]:
    """A list of all :class:`~lightning.pytorch.callbacks.Checkpoint` instances (e.g.
    :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint`) found in the Trainer.callbacks
    list."""
    return [c for c in self.callbacks if isinstance(c, Checkpoint)]

@property
def progress_bar_callback(self) -> Optional[ProgressBar]:
    """An instance of :class:`~lightning.pytorch.callbacks.progress.progress_bar.ProgressBar` found in the
    Trainer.callbacks list, or ``None`` if one doesn't exist."""
    return next((c for c in self.callbacks if isinstance(c, ProgressBar)), None)

@property
def ckpt_path(self) -> Optional[_PATH]:
    """Set to the path/URL of a checkpoint loaded via :meth:`~lightning.pytorch.trainer.trainer.Trainer.fit`,
    :meth:`~lightning.pytorch.trainer.trainer.Trainer.validate`,
    :meth:`~lightning.pytorch.trainer.trainer.Trainer.test`, or
    :meth:`~lightning.pytorch.trainer.trainer.Trainer.predict`. ``None`` otherwise."""
    return self._checkpoint_connector._ckpt_path

@ckpt_path.setter
def ckpt_path(self, ckpt_path: Optional[_PATH]) -> None:
    """Allows you to manage which checkpoint is loaded statefully.

    .. code-block:: python

        trainer = Trainer()
        trainer.ckpt_path = "my/checkpoint/file.ckpt"
        trainer.fit(model)
        ...

        # you will be in charge of resetting this
        trainer.ckpt_path = None
        trainer.test(model)

    """
    self._checkpoint_connector._ckpt_path = ckpt_path
    # a falsy path (None / "") hands control back to the connector
    self._checkpoint_connector._user_managed = bool(ckpt_path)
def save_checkpoint(
    self, filepath: _PATH, weights_only: bool = False, storage_options: Optional[Any] = None
) -> None:
    r"""Runs routine to create a checkpoint.

    This method needs to be called on all processes in case the selected strategy is handling distributed
    checkpointing.

    Args:
        filepath: Path where checkpoint is saved.
        weights_only: If ``True``, will only save the model weights.
        storage_options: parameter for how to save to storage, passed to ``CheckpointIO`` plugin

    Raises:
        AttributeError:
            If the model is not attached to the Trainer before calling this method.

    """
    if self.model is None:
        raise AttributeError(
            "Saving a checkpoint is only possible if a model is attached to the Trainer. Did you call"
            " `Trainer.save_checkpoint()` before calling `Trainer.{fit,validate,test,predict}`?"
        )
    checkpoint = self._checkpoint_connector.dump_checkpoint(weights_only)
    self.strategy.save_checkpoint(checkpoint, filepath, storage_options=storage_options)
    # synchronize so no process runs ahead of the save
    self.strategy.barrier("Trainer.save_checkpoint")
# ----------------------------
# State properties
# ----------------------------

@property
def interrupted(self) -> bool:
    """Whether the trainer status is ``TrainerStatus.INTERRUPTED``."""
    return self.state.status == TrainerStatus.INTERRUPTED

@property
def training(self) -> bool:
    """Whether the current stage is ``RunningStage.TRAINING``."""
    return self.state.stage == RunningStage.TRAINING

@training.setter
def training(self, val: bool) -> None:
    # setting False only clears the stage if we are currently in this stage
    if val:
        self.state.stage = RunningStage.TRAINING
    elif self.training:
        self.state.stage = None

@property
def testing(self) -> bool:
    """Whether the current stage is ``RunningStage.TESTING``."""
    return self.state.stage == RunningStage.TESTING

@testing.setter
def testing(self, val: bool) -> None:
    if val:
        self.state.stage = RunningStage.TESTING
    elif self.testing:
        self.state.stage = None

@property
def predicting(self) -> bool:
    """Whether the current stage is ``RunningStage.PREDICTING``."""
    return self.state.stage == RunningStage.PREDICTING

@predicting.setter
def predicting(self, val: bool) -> None:
    if val:
        self.state.stage = RunningStage.PREDICTING
    elif self.predicting:
        self.state.stage = None

@property
def validating(self) -> bool:
    """Whether the current stage is ``RunningStage.VALIDATING``."""
    return self.state.stage == RunningStage.VALIDATING

@validating.setter
def validating(self, val: bool) -> None:
    if val:
        self.state.stage = RunningStage.VALIDATING
    elif self.validating:
        self.state.stage = None

@property
def evaluating(self) -> bool:
    """Whether a stage is set and that stage reports ``evaluating``."""
    return self.state.stage is not None and self.state.stage.evaluating

@property
def sanity_checking(self) -> bool:
    """Whether sanity checking is running.

    Useful to disable some hooks, logging or callbacks during the sanity checking.

    """
    return self.state.stage == RunningStage.SANITY_CHECKING

@sanity_checking.setter
def sanity_checking(self, val: bool) -> None:
    if val:
        self.state.stage = RunningStage.SANITY_CHECKING
    elif self.sanity_checking:
        self.state.stage = None

@property
def received_sigterm(self) -> bool:
    """Whether a ``signal.SIGTERM`` signal was received.

    For example, this can be checked to exit gracefully.

    """
    return self._signal_connector.received_sigterm

# ----------------------------
# Loop properties
# ----------------------------

@property
def global_step(self) -> int:
    """The number of optimizer steps taken (does not reset each epoch).

    This includes multiple optimizers (if enabled).

    """
    return self.fit_loop.epoch_loop.global_step

@property
def current_epoch(self) -> int:
    """The current epoch, updated after the epoch end hooks are run."""
    return self.fit_loop.epoch_progress.current.completed

@property
def max_epochs(self) -> Optional[int]:
    """The fit loop's ``max_epochs``."""
    return self.fit_loop.max_epochs

@property
def min_epochs(self) -> Optional[int]:
    """The fit loop's ``min_epochs``."""
    return self.fit_loop.min_epochs

@property
def max_steps(self) -> int:
    """The fit loop's ``max_steps``."""
    return self.fit_loop.max_steps

@property
def min_steps(self) -> Optional[int]:
    """The fit loop's ``min_steps``."""
    return self.fit_loop.min_steps

@property
def is_last_batch(self) -> bool:
    """Whether trainer is executing the last batch."""
    return self.fit_loop.epoch_loop.batch_progress.is_last_batch

@property
def train_dataloader(self) -> Optional[TRAIN_DATALOADERS]:
    """The training dataloader(s) used during ``trainer.fit()``."""
    if (combined_loader := self.fit_loop._combined_loader) is not None:
        return combined_loader.iterables
    return None

@property
def val_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
    """The validation dataloader(s) used during ``trainer.fit()`` or ``trainer.validate()``."""
    # prefer the fit loop's validation loader, fall back to the standalone validate loop
    if (combined_loader := self.fit_loop.epoch_loop.val_loop._combined_loader) is not None or (
        combined_loader := self.validate_loop._combined_loader
    ) is not None:
        return combined_loader.iterables
    return None

@property
def test_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
    """The test dataloader(s) used during ``trainer.test()``."""
    if (combined_loader := self.test_loop._combined_loader) is not None:
        return combined_loader.iterables
    return None

@property
def predict_dataloaders(self) -> Optional[EVAL_DATALOADERS]:
    """The prediction dataloader(s) used during ``trainer.predict()``."""
    if (combined_loader := self.predict_loop._combined_loader) is not None:
        return combined_loader.iterables
    return None

@property
def num_training_batches(self) -> Union[int, float]:
    """The number of training batches that will be used during ``trainer.fit()``."""
    return self.fit_loop.max_batches

@property
def num_sanity_val_batches(self) -> List[Union[int, float]]:
    """The number of validation batches that will be used during the sanity-checking part of
    ``trainer.fit()``."""
    max_batches = self.fit_loop.epoch_loop.val_loop.max_batches
    # re-compute the `min` in case this is called outside the sanity-checking stage
    return [min(self.num_sanity_val_steps, batches) for batches in max_batches]

@property
def num_val_batches(self) -> List[Union[int, float]]:
    """The number of validation batches that will be used during ``trainer.fit()`` or ``trainer.validate()``."""
    if self.state.fn == TrainerFn.VALIDATING:
        return self.validate_loop.max_batches
    # if no trainer.fn is set, assume fit's validation
    # use the protected access, because it shouldn't return the sanity_val batches
    return self.fit_loop.epoch_loop.val_loop._max_batches

@property
def num_test_batches(self) -> List[Union[int, float]]:
    """The number of test batches that will be used during ``trainer.test()``."""
    return self.test_loop.max_batches

@property
def num_predict_batches(self) -> List[Union[int, float]]:
    """The number of prediction batches that will be used during ``trainer.predict()``."""
    return self.predict_loop.max_batches

@property
def _evaluation_loop(self) -> _EvaluationLoop:
    """The evaluation loop for the current ``state.fn`` (fit's validation, validate or test).

    Raises:
        RuntimeError: For any other ``state.fn``.
    """
    if self.state.fn == TrainerFn.FITTING:
        return self.fit_loop.epoch_loop.val_loop
    if self.state.fn == TrainerFn.VALIDATING:
        return self.validate_loop
    if self.state.fn == TrainerFn.TESTING:
        return self.test_loop
    raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")

@property
def _active_loop(self) -> Optional[Union[_FitLoop, _EvaluationLoop, _PredictionLoop]]:
    """The loop matching the current stage, or ``None`` if no stage is active."""
    if self.training:
        return self.fit_loop
    if self.sanity_checking or self.evaluating:
        return self._evaluation_loop
    if self.predicting:
        return self.predict_loop
    return None

# ----------------------------
# Logging properties
# ----------------------------

@property
def logger(self) -> Optional[Logger]:
    """The first :class:`~lightning.pytorch.loggers.logger.Logger` being used."""
    return self.loggers[0] if self.loggers else None

@logger.setter
def logger(self, logger: Optional[Logger]) -> None:
    # a falsy value clears all loggers; otherwise it becomes the only logger
    self.loggers = [logger] if logger else []

@property
def loggers(self) -> List[Logger]:
    """The list of :class:`~lightning.pytorch.loggers.logger.Logger` used.

    .. code-block:: python

        for logger in trainer.loggers:
            logger.log_metrics({"foo": 1.0})

    """
    return self._loggers

@loggers.setter
def loggers(self, loggers: Optional[List[Logger]]) -> None:
    self._loggers = loggers if loggers else []

@property
def callback_metrics(self) -> _OUT_DICT:
    """The metrics available to callbacks.

    .. code-block:: python

        def training_step(self, batch, batch_idx):
            self.log("a_val", 2.0)


        callback_metrics = trainer.callback_metrics
        assert callback_metrics["a_val"] == 2.0

    """
    return self._logger_connector.callback_metrics

@property
def logged_metrics(self) -> _OUT_DICT:
    """The metrics sent to the loggers.

    This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log` with the
    :paramref:`~lightning.pytorch.core.LightningModule.log.logger` argument set.

    """
    return self._logger_connector.logged_metrics

@property
def progress_bar_metrics(self) -> _PBAR_DICT:
    """The metrics sent to the progress bar.

    This includes metrics logged via :meth:`~lightning.pytorch.core.LightningModule.log` with the
    :paramref:`~lightning.pytorch.core.LightningModule.log.prog_bar` argument set.

    """
    return self._logger_connector.progress_bar_metrics

@property
def _results(self) -> Optional[_ResultCollection]:
    """The ``_results`` of the active loop, or ``None`` if no loop is active."""
    active_loop = self._active_loop
    if active_loop is not None:
        return active_loop._results
    return None

# ----------------------------
# Other
# ----------------------------

@property
def estimated_stepping_batches(self) -> Union[int, float]:
    r"""The estimated number of batches that will ``optimizer.step()`` during training.

    This accounts for gradient accumulation and the current trainer configuration. This might be used when setting
    up your training dataloader, if it hasn't been set up already.

    .. code-block:: python

        def configure_optimizers(self):
            optimizer = ...
            stepping_batches = self.trainer.estimated_stepping_batches
            scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
            return [optimizer], [scheduler]

    Returns ``float("inf")`` when neither the epochs nor the steps bound the run.

    Raises:
        MisconfigurationException:
            If estimated stepping batches cannot be computed due to different `accumulate_grad_batches`
            at different epochs. (Not raised directly in this body.)

    """
    # infinite training
    if self.max_epochs == -1:
        return float("inf") if self.max_steps == -1 else self.max_steps

    if self.train_dataloader is None:
        rank_zero_info("Loading `train_dataloader` to estimate number of stepping batches.")
        self.fit_loop.setup_data()

    total_batches = self.num_training_batches

    # iterable dataset
    if total_batches == float("inf"):
        # an unbounded epoch is only bounded by `max_steps`; `-1` means "no step limit", so report infinity
        # (as the `max_epochs == -1` branch does) instead of leaking the `-1` sentinel as a step count
        return float("inf") if self.max_steps == -1 else self.max_steps

    assert self.max_epochs is not None
    max_estimated_steps = math.ceil(total_batches / self.accumulate_grad_batches) * max(self.max_epochs, 1)

    max_estimated_steps = min(max_estimated_steps, self.max_steps) if self.max_steps != -1 else max_estimated_steps
    return max_estimated_steps
To analyze traffic and optimize your experience, we serve cookies on this
site. By clicking or navigating, you agree to allow our usage of cookies.
Read PyTorch Lightning's
Privacy Policy.
You are viewing an outdated version of PyTorch Lightning Docs