# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Trainer to automate the training."""
import logging
import os
import traceback
import warnings
from datetime import timedelta
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from weakref import proxy

import torch

import pytorch_lightning as pl
from pytorch_lightning.accelerators import Accelerator, IPUAccelerator
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.loops import TrainingBatchLoop, TrainingEpochLoop
from pytorch_lightning.loops.dataloader.evaluation_loop import EvaluationLoop
from pytorch_lightning.loops.dataloader.prediction_loop import PredictionLoop
from pytorch_lightning.loops.fit_loop import FitLoop
from pytorch_lightning.plugins import Plugin
from pytorch_lightning.plugins.environments import ClusterEnvironment
from pytorch_lightning.profiler import (
    AdvancedProfiler,
    BaseProfiler,
    PassThroughProfiler,
    PyTorchProfiler,
    SimpleProfiler,
    XLAProfiler,
)
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.connectors.callback_connector import CallbackConnector
from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector
from pytorch_lightning.trainer.connectors.data_connector import DataConnector
from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector
from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.trainer.connectors.model_connector import ModelConnector
from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.connectors.slurm_connector import SLURMConnector
from pytorch_lightning.trainer.connectors.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.deprecated_api import DeprecatedTrainerAttributes
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import TrainerFn, TrainerState, TrainerStatus
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.auto_gpu_select import pick_multiple_gpus
from pytorch_lightning.tuner.lr_finder import _LRFinder
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import (
    _IPU_AVAILABLE,
    _TPU_AVAILABLE,
    device_parser,
    DeviceType,
    parsing,
    rank_zero_deprecation,
    rank_zero_info,
    rank_zero_warn,
)
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.seed import reset_seed
from pytorch_lightning.utilities.types import _EVALUATE_OUTPUT, _PREDICT_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS

log = logging.getLogger(__name__)
# warnings to ignore in trainer
warnings.filterwarnings(
    "ignore", message="torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead"
)

class Trainer(
    TrainerProperties,
    TrainerCallbackHookMixin,
    TrainerModelHooksMixin,
    TrainerOptimizersMixin,
    TrainerLoggingMixin,
    TrainerTrainingTricksMixin,
    TrainerDataLoadingMixin,
    DeprecatedTrainerAttributes,
):
    @_defaults_from_env_vars
    def __init__(
        self,
        logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
        checkpoint_callback: bool = True,
        callbacks: Optional[Union[List[Callback], Callback]] = None,
        default_root_dir: Optional[str] = None,
        gradient_clip_val: float = 0.0,
        gradient_clip_algorithm: str = "norm",
        process_position: int = 0,
        num_nodes: int = 1,
        num_processes: int = 1,
        devices: Optional[Union[List[int], str, int]] = None,
        gpus: Optional[Union[List[int], str, int]] = None,
        auto_select_gpus: bool = False,
        tpu_cores: Optional[Union[List[int], str, int]] = None,
        ipus: Optional[int] = None,
        log_gpu_memory: Optional[str] = None,
        progress_bar_refresh_rate: Optional[int] = None,
        overfit_batches: Union[int, float] = 0.0,
        track_grad_norm: Union[int, float, str] = -1,
        check_val_every_n_epoch: int = 1,
        fast_dev_run: Union[int, bool] = False,
        accumulate_grad_batches: Union[int, Dict[int, int], List[list]] = 1,
        max_epochs: Optional[int] = None,
        min_epochs: Optional[int] = None,
        max_steps: Optional[int] = None,
        min_steps: Optional[int] = None,
        max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
        limit_train_batches: Union[int, float] = 1.0,
        limit_val_batches: Union[int, float] = 1.0,
        limit_test_batches: Union[int, float] = 1.0,
        limit_predict_batches: Union[int, float] = 1.0,
        val_check_interval: Union[int, float] = 1.0,
        flush_logs_every_n_steps: int = 100,
        log_every_n_steps: int = 50,
        accelerator: Optional[Union[str, Accelerator]] = None,
        sync_batchnorm: bool = False,
        precision: int = 32,
        weights_summary: Optional[str] = "top",
        weights_save_path: Optional[str] = None,
        num_sanity_val_steps: int = 2,
        truncated_bptt_steps: Optional[int] = None,
        resume_from_checkpoint: Optional[Union[Path, str]] = None,
        profiler: Optional[Union[BaseProfiler, str]] = None,
        benchmark: bool = False,
        deterministic: bool = False,
        reload_dataloaders_every_n_epochs: int = 0,
        reload_dataloaders_every_epoch: bool = False,
        auto_lr_find: Union[bool, str] = False,
        replace_sampler_ddp: bool = True,
        terminate_on_nan: bool = False,
        auto_scale_batch_size: Union[str, bool] = False,
        prepare_data_per_node: bool = True,
        plugins: Optional[Union[List[Union[Plugin, ClusterEnvironment, str]], Plugin, ClusterEnvironment, str]] = None,
        amp_backend: str = "native",
        amp_level: str = "O2",
        distributed_backend: Optional[str] = None,
        move_metrics_to_cpu: bool = False,
        multiple_trainloader_mode: str = "max_size_cycle",
        stochastic_weight_avg: bool = False,
    ):
        r"""
        Customize every aspect of training via flags.

        Args:

            accelerator: Previously known as distributed_backend (dp, ddp, ddp2, etc...).
                Can also take in an accelerator object for custom hardware.

            accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

            amp_backend: The mixed precision backend to use ("native" or "apex").

            amp_level: The optimization level to use (O1, O2, etc...).

            auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
                trying to optimize the initial learning rate for faster convergence. trainer.tune() will
                set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
                To use a different key, set a string instead of True with the key name.

            auto_scale_batch_size: If set to True, will `initially` run a batch size finder trying to find
                the largest batch size that fits into memory. The result will be stored in self.batch_size
                in the LightningModule. Additionally, can be set to either `power` that estimates the batch size
                through a power search or `binsearch` that estimates the batch size through a binary search.

            auto_select_gpus: If enabled and `gpus` is an integer, pick available gpus automatically.
                This is especially useful when GPUs are configured to be in "exclusive mode", such that
                only one process at a time can access them.

            benchmark: If true enables cudnn.benchmark.

            callbacks: Add a callback or list of callbacks.

            checkpoint_callback: If ``True``, enable checkpointing.
                It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
                :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.

            check_val_every_n_epoch: Check val every n train epochs.

            default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
                Default: ``os.getcwd()``.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

            deterministic: If true enables cudnn.deterministic.

            devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
                based on the accelerator type.

            distributed_backend: Deprecated. Please use 'accelerator'.

            fast_dev_run: Runs n batch(es) if set to ``n`` (int), else 1 if set to ``True``,
                of train, val and test to find any bugs (ie: a sort of unit test).

            flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            gpus: Number of gpus to train on (int) or which GPUs to train on (list or str) applied per node.

            gradient_clip_val: 0 means don't clip.

            gradient_clip_algorithm: 'value' means clip_by_value, 'norm' means clip_by_norm. Default: 'norm'

            limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).

            limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).

            limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).

            limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).

            logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
                the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are
                provided and the `save_dir` property of that logger is not set, local files (checkpoints,
                profiler traces, etc.) are saved in ``default_root_dir`` rather than in the ``log_dir`` of any
                of the individual loggers.

            log_gpu_memory: None, 'min_max', 'all'. Might slow performance.

            log_every_n_steps: How often to log within steps (defaults to every 50 steps).

            prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
                Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data.

            process_position: Orders the progress bar when running multiple models on the same machine.

            progress_bar_refresh_rate: How often to refresh the progress bar (in steps). Value ``0`` disables the
                progress bar. Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`.
                Default: None, means a suitable value will be chosen based on the environment
                (terminal, Google COLAB, etc.).

            profiler: To profile individual steps during training and assist in identifying bottlenecks.

            overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).

            plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

            precision: Double precision (64), full precision (32) or half precision (16).
                Can be used on CPU, GPU or TPUs.

            max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
                If both max_epochs and max_steps are not specified, defaults to ``max_epochs`` = 1000.

            min_epochs: Force training for at least this many epochs. Disabled by default (None).
                If both min_epochs and min_steps are not specified, defaults to ``min_epochs`` = 1.

            max_steps: Stop training after this number of steps. Disabled by default (None).

            min_steps: Force training for at least this number of steps. Disabled by default (None).

            max_time: Stop training after this amount of time has passed. Disabled by default (None).
                The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
                :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
                :class:`datetime.timedelta`.

            num_nodes: Number of GPU nodes for distributed training.

            num_processes: Number of processes for distributed training with distributed_backend="ddp_cpu".

            num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
                Set it to `-1` to run all batches in all validation dataloaders.

            reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs.
                Default: 0

            reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

                .. deprecated:: v1.4
                    ``reload_dataloaders_every_epoch`` has been deprecated in v1.4 and will be removed in v1.6.
                    Please use ``reload_dataloaders_every_n_epochs``.

            replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
                will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
                the train sampler and ``shuffle=False`` for the val/test sampler. If you want to customize it,
                you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

            resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
                no checkpoint file at the path, start from scratch. If resuming from a mid-epoch checkpoint,
                training will start from the beginning of the next epoch.

            sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

            terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
                end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1].

            ipus: How many IPUs to train on.

            track_grad_norm: -1 means no tracking. Otherwise tracks that p-norm. May be set to 'inf' for the
                infinity-norm.

            truncated_bptt_steps: Deprecated in v1.3 to be removed in 1.5. Please use
                :paramref:`~pytorch_lightning.core.lightning.LightningModule.truncated_bptt_steps` instead.

            val_check_interval: How often to check the validation set. Use float to check within a training epoch,
                use int to check every n steps (batches).

            weights_summary: Prints a summary of the weights when training begins.

            weights_save_path: Where to save weights if specified. Will override default_root_dir
                for checkpoints only. Use this if for whatever reason you need the checkpoints
                stored in a different place than the logs written in `default_root_dir`.
                Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
                Defaults to `default_root_dir`.

            move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
                This can save some gpu memory, but can make training slower. Use with attention.

            multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
                In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
                and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
                reload when reaching the minimum length of datasets.

            stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
                <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_.
        """
        super().__init__()
        Trainer._log_api_event("init")
        self.state = TrainerState()

        gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)

        # init connectors
        self.dev_debugger = InternalDebugger(self)
        self.config_validator = ConfigValidator(self)
        self.data_connector = DataConnector(self, multiple_trainloader_mode)
        self.optimizer_connector = OptimizerConnector(self)

        self.accelerator_connector = AcceleratorConnector(
            num_processes,
            devices,
            tpu_cores,
            ipus,
            distributed_backend,
            accelerator,
            gpus,
            gpu_ids,
            num_nodes,
            sync_batchnorm,
            benchmark,
            replace_sampler_ddp,
            deterministic,
            precision,
            amp_backend,
            amp_level,
            plugins,
        )
        self.logger_connector = LoggerConnector(self, log_gpu_memory)
        self.model_connector = ModelConnector(self)
        self.callback_connector = CallbackConnector(self)
        self.debugging_connector = DebuggingConnector(self)
        self.training_tricks_connector = TrainingTricksConnector(self)
        self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
        self.slurm_connector = SLURMConnector(self)
        self.tuner = Tuner(self)

        fit_loop = FitLoop(
            min_epochs=(1 if (min_epochs is None and min_steps is None) else min_epochs),
            max_epochs=(1000 if (max_epochs is None and max_steps is None) else max_epochs),
        )
        training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
        training_batch_loop = TrainingBatchLoop()
        training_validation_loop = EvaluationLoop()
        training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
        fit_loop.connect(epoch_loop=training_epoch_loop)

        # default .fit() loop
        self.fit_loop = fit_loop

        # default .validate() loop
        self.validate_loop = EvaluationLoop()

        # default .test() loop
        self.test_loop = EvaluationLoop()

        # default .predict() loop
        self.predict_loop = PredictionLoop()

        # training state
        if weights_summary is not None and weights_summary not in ModelSummary.MODES:
            raise MisconfigurationException(
                f"`weights_summary` can be None, {', '.join(ModelSummary.MODES)}, but got {weights_summary}"
            )
        self.weights_summary = weights_summary
        self.shown_warnings = set()

        # init callbacks
        # Declare attributes to be set in callback_connector on_trainer_init
        self.callback_connector.on_trainer_init(
            callbacks,
            checkpoint_callback,
            progress_bar_refresh_rate,
            process_position,
            default_root_dir,
            weights_save_path,
            stochastic_weight_avg,
            max_time,
        )

        # hook
        self.on_init_start()

        # init optimizer + lr scheduler related flags
        self.optimizer_connector.on_trainer_init()

        # init data flags
        self.data_connector.on_trainer_init(
            check_val_every_n_epoch,
            reload_dataloaders_every_n_epochs,
            reload_dataloaders_every_epoch,
            prepare_data_per_node,
        )

        # init training tricks
        self.training_tricks_connector.on_trainer_init(
            gradient_clip_val,
            gradient_clip_algorithm,
            track_grad_norm,
            accumulate_grad_batches,
            truncated_bptt_steps,
            terminate_on_nan,
        )
        self._setup_on_init(num_sanity_val_steps)

        # configure tuner
        self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

        # configure profiler
        self.__init_profiler(profiler)

        # init logger flags
        self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

        # init debugging flags
        self.debugging_connector.on_init_start(
            limit_train_batches,
            limit_val_batches,
            limit_test_batches,
            limit_predict_batches,
            val_check_interval,
            overfit_batches,
            fast_dev_run,
        )

        # Callback system
        self.on_init_end()
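    # Illustrative sketch (not part of the original source): constructing a ``Trainer`` with a few of the
    # flags documented above. The flag values are arbitrary placeholders chosen for demonstration only.
    #
    #     import pytorch_lightning as pl
    #
    #     trainer = pl.Trainer(
    #         max_epochs=10,            # stop after 10 epochs
    #         gpus=1,                   # train on a single GPU
    #         precision=16,             # half precision
    #         gradient_clip_val=0.5,    # clip gradients (by norm, the default algorithm)
    #         profiler="simple",        # one of: "simple", "advanced", "pytorch", "xla"
    #     )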
    def _setup_on_init(self, num_sanity_val_steps: int) -> None:
        self._log_device_info()

        self.should_stop = False
        self.state = TrainerState()
        self.num_training_batches = 0
        self.train_dataloader = None

        if num_sanity_val_steps == -1:
            self.num_sanity_val_steps = float("inf")
        else:
            self.num_sanity_val_steps = num_sanity_val_steps

        self.num_sanity_val_batches = []
        self.num_test_batches = []
        self.num_val_batches = []
        self.test_dataloaders = None
        self.val_dataloaders = None

        # .validate() and .test() set this when they load a checkpoint
        self.validated_ckpt_path = None
        self.tested_ckpt_path = None

        # when true, print evaluation results in .validate() and .test()
        self.verbose_evaluate = True

        self.num_predict_batches = []
        self.predicted_ckpt_path = None
    def fit(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        train_dataloader=None,  # noqa TODO: remove with 1.6
    ) -> None:
        r"""
        Runs the full optimization routine.

        Args:
            model: Model to fit.

            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation
                samples.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.
        """
        Trainer._log_api_event("fit")

        self.state.fn = TrainerFn.FITTING
        self.state.status = TrainerStatus.RUNNING
        self.training = True

        if train_dataloader is not None:
            rank_zero_deprecation(
                "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
                " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
            )
            train_dataloaders = train_dataloader

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
            train_dataloaders = None

        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
            )

        # links data to the trainer
        self.data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        self.checkpoint_connector.resume_start()

        self._run(model)

        assert self.state.stopped
        self.training = False
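    # Illustrative sketch (not part of the original source): the two usual ways of feeding data to ``fit``.
    # ``MyModel``, ``MyDataModule``, ``train_loader`` and ``val_loader`` are placeholders, not defined here.
    #
    #     model = MyModel()
    #
    #     # 1) pass dataloaders directly
    #     trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
    #
    #     # 2) or pass a LightningDataModule; dataloaders and datamodule are mutually exclusive
    #     trainer.fit(model, datamodule=MyDataModule())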
    def validate(
        self,
        model: Optional["pl.LightningModule"] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        ckpt_path: Optional[str] = "best",
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
        val_dataloaders=None,  # noqa TODO: remove with 1.6
    ) -> _EVALUATE_OUTPUT:
        r"""
        Perform one evaluation epoch over the validation set.

        Args:
            model: The model to validate.

            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples.

            ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
                If ``None``, use the current weights of the model.
                When the model is given as argument, this parameter will not apply.

            verbose: If True, prints the validation results.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        Returns:
            List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks
            like :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
            :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
            The length of the list corresponds to the number of validation dataloaders used.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("validate")
        self.verbose_evaluate = verbose

        self.state.fn = TrainerFn.VALIDATING
        self.state.status = TrainerStatus.RUNNING
        self.validating = True

        if val_dataloaders is not None:
            rank_zero_deprecation(
                "`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
                " Use `trainer.validate(dataloaders)` instead."
            )
            dataloaders = val_dataloaders

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(dataloaders, LightningDataModule):
            datamodule = dataloaders
            dataloaders = None

        # If you supply a datamodule you can't supply val_dataloaders
        if dataloaders is not None and datamodule:
            raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")

        model_provided = model is not None
        model = model or self.lightning_module
        if model is None:
            raise MisconfigurationException(
                "`model` must be provided to `trainer.validate()` when it hasn't been passed in a previous run"
            )

        # links data to the trainer
        self.data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

        if not model_provided:
            self.validated_ckpt_path = self.__load_ckpt_weights(ckpt_path)

        # run validate
        results = self._run(model)

        assert self.state.stopped
        self.validating = False

        return results
    def test(
        self,
        model: Optional["pl.LightningModule"] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        ckpt_path: Optional[str] = "best",
        verbose: bool = True,
        datamodule: Optional[LightningDataModule] = None,
        test_dataloaders=None,  # noqa TODO: remove with 1.6
    ) -> _EVALUATE_OUTPUT:
        r"""
        Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
        test set until you want to.

        Args:
            model: The model to test.

            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples.

            ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
                If ``None``, use the current weights of the model.
                When the model is given as argument, this parameter will not apply.

            verbose: If True, prints the test results.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        Returns:
            List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks
            like :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`,
            :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
            The length of the list corresponds to the number of test dataloaders used.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("test")
        self.verbose_evaluate = verbose

        self.state.fn = TrainerFn.TESTING
        self.state.status = TrainerStatus.RUNNING
        self.testing = True

        if test_dataloaders is not None:
            rank_zero_deprecation(
                "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
                " Use `trainer.test(dataloaders)` instead."
            )
            dataloaders = test_dataloaders

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(dataloaders, LightningDataModule):
            datamodule = dataloaders
            dataloaders = None

        # If you supply a datamodule you can't supply test_dataloaders
        if dataloaders is not None and datamodule:
            raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")

        model_provided = model is not None
        model = model or self.lightning_module
        if model is None:
            raise MisconfigurationException(
                "`model` must be provided to `trainer.test()` when it hasn't been passed in a previous run"
            )

        # links data to the trainer
        self.data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

        if not model_provided:
            self.tested_ckpt_path = self.__load_ckpt_weights(ckpt_path)

        # run test
        results = self._run(model)

        assert self.state.stopped
        self.testing = False

        return results
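    # Illustrative sketch (not part of the original source): evaluating with ``validate``/``test`` after a
    # ``fit`` run. When no ``model`` is passed, the default ``ckpt_path="best"`` loads the best checkpoint
    # tracked by the ``ModelCheckpoint`` callback. ``model``, ``dm`` and ``test_loader`` are placeholders.
    #
    #     trainer.fit(model, datamodule=dm)
    #     trainer.validate(datamodule=dm)                              # loads the best checkpoint from fitting
    #     trainer.test(model, dataloaders=test_loader, verbose=True)   # evaluates the given model's current weights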
    def predict(
        self,
        model: Optional["pl.LightningModule"] = None,
        dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
        datamodule: Optional[LightningDataModule] = None,
        return_predictions: Optional[bool] = None,
        ckpt_path: Optional[str] = "best",
    ) -> Optional[_PREDICT_OUTPUT]:
        r"""
        Separates from fit to make sure you never run on your predictions set until you want to.
        This will call the model forward function to compute predictions.

        Args:
            model: The model to predict with.

            dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
                or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.

            datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders.

            return_predictions: Whether to return predictions.
                ``True`` by default except when an accelerator that spawns processes is used (not supported).

            ckpt_path: Either ``best`` or path to the checkpoint you wish to use to predict.
                If ``None``, use the current weights of the model.
                When the model is given as argument, this parameter will not apply.

        Returns:
            Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
        """
        # --------------------
        # SETUP HOOK
        # --------------------
        Trainer._log_api_event("predict")

        self.state.fn = TrainerFn.PREDICTING
        self.state.status = TrainerStatus.RUNNING
        self.predicting = True

        self.predict_loop.return_predictions = return_predictions

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(dataloaders, LightningDataModule):
            datamodule = dataloaders
            dataloaders = None

        if dataloaders is not None and datamodule:
            raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")

        model_provided = model is not None
        model = model or self.lightning_module
        if model is None:
            raise MisconfigurationException(
                "`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run"
            )

        # links data to the trainer
        self.data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

        if not model_provided:
            self.predicted_ckpt_path = self.__load_ckpt_weights(ckpt_path)

        results = self._run(model)

        assert self.state.stopped
        self.predicting = False

        return results
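    # Illustrative sketch (not part of the original source): ``predict`` calls the model's forward logic over the
    # given dataloader(s) and returns one list of predictions per dataloader, unless ``return_predictions=False``.
    # ``model`` and ``predict_loader`` are placeholders.
    #
    #     predictions = trainer.predict(model, dataloaders=predict_loader, ckpt_path=None)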
    def tune(
        self,
        model: "pl.LightningModule",
        train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
        val_dataloaders: Optional[EVAL_DATALOADERS] = None,
        datamodule: Optional[LightningDataModule] = None,
        scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
        lr_find_kwargs: Optional[Dict[str, Any]] = None,
        train_dataloader=None,  # noqa TODO: remove with 1.6
    ) -> Dict[str, Optional[Union[int, _LRFinder]]]:
        r"""
        Runs routines to tune hyperparameters before training.

        Args:
            model: Model to tune.

            train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
                :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
                In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

            val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation
                samples.

            datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

            scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

            lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`
        """
        Trainer._log_api_event("tune")

        self.state.fn = TrainerFn.TUNING
        self.state.status = TrainerStatus.RUNNING
        self.tuning = True

        if train_dataloader is not None:
            rank_zero_deprecation(
                "`trainer.tune(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
                " Use `trainer.tune(train_dataloaders)` instead. HINT: added 's'"
            )
            train_dataloaders = train_dataloader

        # if a datamodule comes in as the second arg, then fix it for the user
        if isinstance(train_dataloaders, LightningDataModule):
            datamodule = train_dataloaders
            train_dataloaders = None

        # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
        if (train_dataloaders is not None or val_dataloaders is not None) and datamodule is not None:
            raise MisconfigurationException(
                "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
            )

        # links data to the trainer
        self.data_connector.attach_data(
            model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
        )

        result = self.tuner._tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs)

        assert self.state.stopped
        self.tuning = False

        return result
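    # Illustrative sketch (not part of the original source): ``tune`` runs the batch size scaler and/or the
    # learning rate finder, depending on the ``auto_scale_batch_size`` / ``auto_lr_find`` flags set at
    # construction time. ``MyModel`` and ``dm`` are placeholders; the model is expected to expose
    # ``batch_size`` and ``lr``/``learning_rate`` hyperparameters.
    #
    #     trainer = pl.Trainer(auto_scale_batch_size="power", auto_lr_find=True)
    #     result = trainer.tune(MyModel(), datamodule=dm)
    #     # result maps tuner names to their outputs (an ``int`` suggestion and/or an ``_LRFinder``),
    #     # per the return annotation above.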
    def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
        # clean hparams
        if hasattr(model, "hparams"):
            parsing.clean_namespace(model.hparams)

        self.config_validator.verify_loop_configurations(model)

        # attach model log function to callback
        self.callback_connector.attach_model_logging_functions(model)

        # hook
        self.data_connector.prepare_data(model)
        self.callback_connector._attach_model_callbacks(model, self)

        # ----------------------------
        # SET UP TRAINING
        # ----------------------------
        self.call_hook("on_before_accelerator_backend_setup", model)
        self.accelerator.connect(model)
        self.accelerator.setup_environment()
        self._call_setup_hook(model)  # allow user to setup lightning_module in accelerator environment

        # restore modules after setup
        self.checkpoint_connector.restore_datamodule()
        self.checkpoint_connector.restore_model()
        # restore callback states
        self.checkpoint_connector.restore_callbacks()

        self._call_configure_sharded_model(model)  # allow user to setup in model sharded environment
        self.accelerator.setup(self, model)  # note: this sets up self.lightning_module

        # ----------------------------
        # INSPECT THE CORE LOOPS
        # ----------------------------
        fr"""
             Lightning internal flow looks like this:
        {Trainer.fit} or {Trainer.test} or {Trainer.predict}   ||
                                |                              ||
                        create accelerator                     ||
                                |                              ||
                         {self._dispatch}                      ||
                                |                              ||  LIGHTNING
                  {self.accelerator.start_training}            ||
                or {self.accelerator.start_evaluating}         ||
                or {self.accelerator.start_predicting}         ||  FLOW
                                |                              ||
                         {self.run_stage}                      ||
                                |                              ||  DIRECTION
                        {self._run_train}                      ||
                     or {self._run_evaluate}                   ||
                     or {self._run_predict}                    ||
                                |                              ||
                             results                           \/
        This is used to guide readers to the core loops: train, test, predict.
        {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
        Search for `start_training` or `start_evaluating` or `start_predicting` in
        `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
        """  # noqa: W605

        # ----------------------------
        # TRAIN
        # ----------------------------
        # hook
        if self.state.fn == TrainerFn.FITTING:
            self.call_hook("on_fit_start")

        # plugin will setup fitting (e.g. ddp will launch child processes)
        self._pre_dispatch()

        # restore optimizers, etc.
        self.checkpoint_connector.restore_training_state()

        # dispatch `start_training` or `start_evaluating` or `start_predicting`
        self._dispatch()

        # plugin will finalize fitting (e.g. ddp_spawn will load trained model)
        self._post_dispatch()

        # ----------------------------
        # POST-Training CLEAN UP
        # ----------------------------
        # hook
        if self.state.fn == TrainerFn.FITTING:
            self.call_hook("on_fit_end")

        # teardown
        self._call_teardown_hook(model)

        if self.state.status != TrainerStatus.INTERRUPTED:
            self.state.status = TrainerStatus.FINISHED
        self.state.stage = None

        return self.accelerator.results

    def _pre_dispatch(self):
        self.accelerator.pre_dispatch(self)
        self._log_hyperparams()

    def _log_hyperparams(self):
        # log hyper-parameters
        hparams_initial = None

        if self.logger is not None:
            # save exp to get started (this is where the first experiment logs are written)
            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
                datamodule_hparams = self.datamodule.hparams_initial
                lightning_hparams = self.lightning_module.hparams_initial

                colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
                if colliding_keys:
                    raise MisconfigurationException(
                        f"Error while merging hparams: the keys {colliding_keys} are present "
                        "in both the LightningModule's and LightningDataModule's hparams."
                    )

                hparams_initial = {**lightning_hparams, **datamodule_hparams}
            elif self.lightning_module._log_hyperparams:
                hparams_initial = self.lightning_module.hparams_initial
            elif datamodule_log_hyperparams:
                hparams_initial = self.datamodule.hparams_initial

            if hparams_initial is not None:
                self.logger.log_hyperparams(hparams_initial)
            self.logger.log_graph(self.lightning_module)
            self.logger.save()

    def _post_dispatch(self):
        self.accelerator.post_dispatch(self)
        # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
        # which need to happen before.
        self.accelerator.teardown()
        self._active_loop.teardown()
        self.logger_connector.teardown()

    def _dispatch(self):
        if self.evaluating:
            self.accelerator.start_evaluating(self)
        elif self.predicting:
            self.accelerator.start_predicting(self)
        else:
            self.accelerator.start_training(self)

    def run_stage(self):
        self.accelerator.dispatch(self)
        self.__setup_profiler()

        if self.evaluating:
            return self._run_evaluate()
        if self.predicting:
            return self._run_predict()
        return self._run_train()

    def _pre_training_routine(self):
        # wait for all to join if on distributed
        self.accelerator.barrier("setup_training")

        # register auto-resubmit when on SLURM
        self.slurm_connector.register_slurm_signal_handlers()

        self.checkpoint_connector.resume_end()

        # --------------------------
        # Pre-train
        # --------------------------
        # on pretrain routine start
        ref_model = self.lightning_module

        self.on_pretrain_routine_start()
        ref_model.on_pretrain_routine_start()

        # print model summary
        if self.is_global_zero and self.weights_summary is not None and not self.testing:
            max_depth = ModelSummary.MODES[self.weights_summary]
            ref_model.summarize(max_depth=max_depth)

        # on pretrain routine end
        self.on_pretrain_routine_end()
        ref_model.on_pretrain_routine_end()

    def _run_train(self) -> None:
        self._pre_training_routine()

        if not self.is_global_zero and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        self._run_sanity_check(self.lightning_module)

        # enable train mode
        self.model.train()
        torch.set_grad_enabled(True)

        # reload data when needed
        model = self.lightning_module
        self.reset_train_val_dataloaders(model)

        try:
            # reset trainer on this loop and all child loops in case user connected a custom loop
            self.fit_loop.trainer = self
            self.fit_loop.run()
        except KeyboardInterrupt:
            rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
            # user could press Ctrl+c many times... only shutdown once
            if not self.interrupted:
                self.state.status = TrainerStatus.INTERRUPTED
                self.on_keyboard_interrupt()
                # same treatment as below
                self.accelerator.on_train_end()
        except BaseException:
            self.state.status = TrainerStatus.INTERRUPTED
            if distributed_available() and self.world_size > 1:
                # try syncing remaining processes, kill otherwise
                self.training_type_plugin.reconciliate_processes(traceback.format_exc())
            # give accelerators a chance to finish
            self.accelerator.on_train_end()
            self._on_exception()
            # reset bookkeeping
            self.state.stage = None
            raise

    def _run_evaluate(self) -> _EVALUATE_OUTPUT:
        if not self.is_global_zero and self.progress_bar_callback is not None:
            self.progress_bar_callback.disable()

        assert self.evaluating

        # reload dataloaders
        self._evaluation_loop.reload_evaluation_dataloaders()

        # reset trainer on this loop and all child loops in case user connected a custom loop
        self._evaluation_loop.trainer = self

        with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
            eval_loop_results = self._evaluation_loop.run()

        # remove the tensors from the eval results
        for i, result in enumerate(eval_loop_results):
            if isinstance(result, dict):
                for k, v in result.items():
                    if isinstance(v, torch.Tensor):
                        result[k] = v.cpu().item()

        return eval_loop_results

    def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
        self.reset_predict_dataloader(self.lightning_module)
        # reset trainer on this loop and all child loops in case user connected a custom loop
        self.predict_loop.trainer = self
        with torch.no_grad():
            return self.predict_loop.run()

    def _run_sanity_check(self, ref_model):
        using_val_step = ref_model.val_dataloader is not None and is_overridden("validation_step", ref_model)
        should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

        # run tiny validation (if validation defined)
        # to make sure program won't crash during val
        if should_sanity_check:
            stage = self.state.stage
            self.sanity_checking = True

            # hook and callback
            self.on_sanity_check_start()

            # reload dataloaders
            self._evaluation_loop.reload_evaluation_dataloaders()

            # run eval step
            with torch.no_grad():
                self._evaluation_loop.run()

            self.on_sanity_check_end()

            # reset validation metrics
            self.logger_connector.reset()

            # reset the seed to what it was before sanity check
            # prevents the sanity check from affecting random sampling in training
            reset_seed()

            # restore the previous stage when the sanity check is finished
            self.state.stage = stage

    def __load_ckpt_weights(self, ckpt_path: Optional[str]) -> Optional[str]:
        if ckpt_path is None:
            return

        fn = self.state.fn.value

        if ckpt_path == "best":
            # if user requests the best checkpoint but we don't have it, error
            if not self.checkpoint_callback.best_model_path:
                if self.fast_dev_run:
                    raise MisconfigurationException(
                        f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
                        f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
                    )
                raise MisconfigurationException(
                    f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
                )
            # load best weights
            ckpt_path = self.checkpoint_callback.best_model_path

        if not ckpt_path:
            raise MisconfigurationException(
                f'`.{fn}()` found no path for the best weights: "{ckpt_path}". Please'
                f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
            )

        # only one process running at this point for TPUs, as spawn isn't triggered yet
        # todo: move this logic internally within the barrier.
        if not self._device_type == DeviceType.TPU:
            self.training_type_plugin.barrier()

        self.checkpoint_connector.restore_model_weights(ckpt_path)
        return ckpt_path

    def _call_setup_hook(self, model: "pl.LightningModule") -> None:
        fn = self.state.fn._setup_fn

        self.accelerator.barrier("pre_setup")

        if self.datamodule is not None:
            self.datamodule.setup(stage=fn)
        self.setup(model, stage=fn)
        model.setup(stage=fn)

        self.accelerator.barrier("post_setup")

    def _call_configure_sharded_model(self, model: "pl.LightningModule") -> None:
        # Call configure sharded model hook if accelerator requests. In some cases
        # we will not call the hook; the hook has initialized the sharded model for example.

        # used on the model if the user re-creates a trainer with resume_from_checkpoint
        model_call_configure_sharded_model_hook = getattr(model, "call_configure_sharded_model_hook", False)
        if self.accelerator.call_configure_sharded_model_hook and not model_call_configure_sharded_model_hook:
            with self.accelerator.model_sharded_context():
                model.configure_sharded_model()
                self.configure_sharded_model(model)
            model.call_configure_sharded_model_hook = True
            self.accelerator.call_configure_sharded_model_hook = False

    def _call_teardown_hook(self, model: "pl.LightningModule") -> None:
        fn = self.state.fn._setup_fn

        if self.datamodule is not None:
            self.datamodule.teardown(stage=fn)
        self.profiler.teardown(stage=fn)
        self.data_connector.detach_data(self.lightning_module)
        self.teardown(stage=fn)
        model.teardown(stage=fn)

        model._current_fx_name = None
        model._current_dataloader_idx = None
        # these could have become stale if metrics are defined in `setup`
        model._metric_attributes = None

    def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
        # Note: this implementation is copy/pasted into TrainingEpochLoop._on_train_epoch_end_hook.
        # This was done to manage the deprecation of the `outputs` argument to on_train_epoch_end.
        # If making changes to this function, ensure that those changes are also made to
        # TrainingEpochLoop._on_train_epoch_end_hook.
        if self.lightning_module:
            prev_fx_name = self.lightning_module._current_fx_name
            self.lightning_module._current_fx_name = hook_name

        # always profile hooks
        with self.profiler.profile(hook_name):

            # first call trainer hook
            if hasattr(self, hook_name):
                trainer_hook = getattr(self, hook_name)
                trainer_hook(*args, **kwargs)

            # next call hook in lightningModule
            output = None
            model_ref = self.lightning_module
            if is_overridden(hook_name, model_ref):
                hook_fx = getattr(model_ref, hook_name)
                output = hook_fx(*args, **kwargs)

            # call the accelerator hook
            if hasattr(self.accelerator, hook_name):
                accelerator_hook = getattr(self.accelerator, hook_name)
                accelerator_output = accelerator_hook(*args, **kwargs)
                # Rely on the accelerator output if lightningModule hook returns nothing
                # Required for cases such as DataParallel where we reduce the output for the user
                # todo: move this data parallel logic into the data parallel plugin
                output = accelerator_output if output is None else output

        if self.lightning_module:
            # restore current_fx when nested context
            self.lightning_module._current_fx_name = prev_fx_name

        return output

    def _parse_devices(
        self,
        gpus: Optional[Union[List[int], str, int]],
        auto_select_gpus: bool,
        tpu_cores: Optional[Union[List[int], str, int]],
    ) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
        if auto_select_gpus and isinstance(gpus, int):
            gpus = pick_multiple_gpus(gpus)

        # TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
        gpu_ids = device_parser.parse_gpu_ids(gpus)
        tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
        return gpu_ids, tpu_cores

    @staticmethod
    def _log_api_event(event: str) -> None:
        torch._C._log_api_usage_once("lightning.trainer." + event)

    def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None:
        if isinstance(profiler, str):
            PROFILERS = {
                "simple": SimpleProfiler,
                "advanced": AdvancedProfiler,
                "pytorch": PyTorchProfiler,
                "xla": XLAProfiler,
            }
            profiler = profiler.lower()
            if profiler not in PROFILERS:
                raise MisconfigurationException(
                    "When passing string value for the `profiler` parameter of `Trainer`,"
                    f" it can only be one of {list(PROFILERS.keys())}"
                )
            profiler_class = PROFILERS[profiler]
            profiler = profiler_class()
        self.profiler: BaseProfiler = profiler or PassThroughProfiler()

    def __setup_profiler(self) -> None:
        local_rank = self.local_rank if self.world_size > 1 else None
        self.profiler._lightning_module = proxy(self.lightning_module)
        self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)

    def _log_device_info(self) -> None:
        rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}")

        num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0
        rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

        num_ipus = self.ipus if self.ipus is not None else 0
        rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

        if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
            rank_zero_warn(
                "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`."
            )

        if _TPU_AVAILABLE and self._device_type != DeviceType.TPU:
            rank_zero_warn(
                "TPU available but not used. Set the `tpu_cores` flag in your trainer"
                " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
            )

        if _IPU_AVAILABLE and self._device_type != DeviceType.IPU and not isinstance(self.accelerator, IPUAccelerator):
            rank_zero_warn(
                "IPU available but not used. Set the `ipus` flag in your trainer `Trainer(ipus=8)` or script `--ipus=8`."
            )

    def _on_exception(self):
        if not self.is_global_zero or not _fault_tolerant_enabled():
            return
        # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
        file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
        self.save_checkpoint(file_path)