# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Trainer to automate the training."""importinspectimportloggingimportosimporttracebackimportwarningsfromargparseimportArgumentParser,NamespacefromdatetimeimporttimedeltafrompathlibimportPathfromtypingimportAny,Callable,cast,Dict,Iterable,List,Optional,Tuple,Unionfromweakrefimportproxyimporttorchfromtorch.optimimportOptimizerimportpytorch_lightningasplfrompytorch_lightning.acceleratorsimportAccelerator,IPUAcceleratorfrompytorch_lightning.callbacksimportCallback,EarlyStopping,ModelCheckpoint,ProgressBarBasefrompytorch_lightning.callbacks.prediction_writerimportBasePredictionWriterfrompytorch_lightning.core.datamoduleimportLightningDataModulefrompytorch_lightning.core.optimizerimportLightningOptimizerfrompytorch_lightning.loggersimportLightningLoggerBasefrompytorch_lightning.loggers.baseimportDummyLogger,LoggerCollectionfrompytorch_lightning.loggers.tensorboardimportTensorBoardLoggerfrompytorch_lightning.loopsimportPredictionLoop,TrainingBatchLoop,TrainingEpochLoopfrompytorch_lightning.loops.dataloader.evaluation_loopimportEvaluationLoopfrompytorch_lightning.loops.fit_loopimportFitLoopfrompytorch_lightning.pluginsimportDDPSpawnPlugin,ParallelPlugin,PLUGIN_INPUT,PrecisionPlugin,TrainingTypePluginfrompytorch_lightning.profilerimport(AdvancedProfiler,BaseProfiler,PassThroughProfiler,PyTorchProfiler,SimpleProfiler,XLAProfiler,)frompytorch_lightning.trainer.callback_hookimportTrainerCallbackHookMixinfrompytorch_li
ghtning.trainer.configuration_validatorimportverify_loop_configurationsfrompytorch_lightning.trainer.connectors.accelerator_connectorimportAcceleratorConnectorfrompytorch_lightning.trainer.connectors.callback_connectorimportCallbackConnectorfrompytorch_lightning.trainer.connectors.checkpoint_connectorimportCheckpointConnectorfrompytorch_lightning.trainer.connectors.data_connectorimportDataConnectorfrompytorch_lightning.trainer.connectors.env_vars_connectorimport_defaults_from_env_varsfrompytorch_lightning.trainer.connectors.logger_connectorimportLoggerConnectorfrompytorch_lightning.trainer.connectors.logger_connector.resultimportResultCollectionfrompytorch_lightning.trainer.connectors.signal_connectorimportSignalConnectorfrompytorch_lightning.trainer.data_loadingimportTrainerDataLoadingMixinfrompytorch_lightning.trainer.model_hooksimportTrainerModelHooksMixinfrompytorch_lightning.trainer.optimizersimportTrainerOptimizersMixinfrompytorch_lightning.trainer.statesimportRunningStage,TrainerFn,TrainerState,TrainerStatusfrompytorch_lightning.tuner.auto_gpu_selectimportpick_multiple_gpusfrompytorch_lightning.tuner.lr_finderimport_LRFinderfrompytorch_lightning.tuner.tuningimportTunerfrompytorch_lightning.utilitiesimport(_IPU_AVAILABLE,_TPU_AVAILABLE,device_parser,DeviceType,DistributedType,GradClipAlgorithmType,parsing,rank_zero_deprecation,rank_zero_info,rank_zero_warn,)frompytorch_lightning.utilities.argparseimport(add_argparse_args,from_argparse_args,parse_argparser,parse_env_variables,)frompytorch_lightning.utilities.cloud_ioimportget_filesystemfrompytorch_lightning.utilities.distributedimportdistributed_availablefrompytorch_lightning.utilities.exceptionsimportExitGracefullyException,MisconfigurationExceptionfrompytorch_lightning.utilities.importsimport_fault_tolerant_trainingfrompytorch_lightning.utilities.metaimportmaterialize_modulefrompytorch_lightning.utilities.model_helpersimportis_overriddenfrompytorch_lightning.utilities.seedimportreset_seedfrompytorch_lightning
.utilities.typesimport(_EVALUATE_OUTPUT,_PATH,_PREDICT_OUTPUT,EVAL_DATALOADERS,LRSchedulerTypeUnion,TRAIN_DATALOADERS,)log=logging.getLogger(__name__)# warnings to ignore in trainerwarnings.filterwarnings("ignore",message="torch.distributed.reduce_op is deprecated, please use torch.distributed.ReduceOp instead")
[docs]classTrainer(TrainerCallbackHookMixin,TrainerModelHooksMixin,TrainerOptimizersMixin,TrainerDataLoadingMixin,):# Needed because of LightningOptimizer_lightning_optimizers=None
@_defaults_from_env_vars
def __init__(
    self,
    logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
    checkpoint_callback: Optional[bool] = None,
    enable_checkpointing: bool = True,
    callbacks: Optional[Union[List[Callback], Callback]] = None,
    default_root_dir: Optional[str] = None,
    gradient_clip_val: Optional[Union[int, float]] = None,
    gradient_clip_algorithm: Optional[str] = None,
    process_position: int = 0,
    num_nodes: int = 1,
    num_processes: int = 1,
    devices: Optional[Union[List[int], str, int]] = None,
    gpus: Optional[Union[List[int], str, int]] = None,
    auto_select_gpus: bool = False,
    tpu_cores: Optional[Union[List[int], str, int]] = None,
    ipus: Optional[int] = None,
    log_gpu_memory: Optional[str] = None,  # TODO: Remove in 1.7
    progress_bar_refresh_rate: Optional[int] = None,  # TODO: remove in v1.7
    enable_progress_bar: bool = True,
    overfit_batches: Union[int, float] = 0.0,
    track_grad_norm: Union[int, float, str] = -1,
    check_val_every_n_epoch: int = 1,
    fast_dev_run: Union[int, bool] = False,
    accumulate_grad_batches: Optional[Union[int, Dict[int, int]]] = None,
    max_epochs: Optional[int] = None,
    min_epochs: Optional[int] = None,
    max_steps: int = -1,
    min_steps: Optional[int] = None,
    max_time: Optional[Union[str, timedelta, Dict[str, int]]] = None,
    limit_train_batches: Union[int, float] = 1.0,
    limit_val_batches: Union[int, float] = 1.0,
    limit_test_batches: Union[int, float] = 1.0,
    limit_predict_batches: Union[int, float] = 1.0,
    val_check_interval: Union[int, float] = 1.0,
    flush_logs_every_n_steps: Optional[int] = None,
    log_every_n_steps: int = 50,
    accelerator: Optional[Union[str, Accelerator]] = None,
    strategy: Optional[Union[str, TrainingTypePlugin]] = None,
    sync_batchnorm: bool = False,
    precision: Union[int, str] = 32,
    enable_model_summary: bool = True,
    weights_summary: Optional[str] = "top",
    weights_save_path: Optional[str] = None,
    num_sanity_val_steps: int = 2,
    resume_from_checkpoint: Optional[Union[Path, str]] = None,
    profiler: Optional[Union[BaseProfiler, str]] = None,
    benchmark: bool = False,
    deterministic: bool = False,
    reload_dataloaders_every_n_epochs: int = 0,
    reload_dataloaders_every_epoch: bool = False,
    auto_lr_find: Union[bool, str] = False,
    replace_sampler_ddp: bool = True,
    detect_anomaly: bool = False,
    auto_scale_batch_size: Union[str, bool] = False,
    prepare_data_per_node: Optional[bool] = None,
    plugins: Optional[Union[PLUGIN_INPUT, List[PLUGIN_INPUT]]] = None,
    amp_backend: str = "native",
    amp_level: Optional[str] = None,
    move_metrics_to_cpu: bool = False,
    multiple_trainloader_mode: str = "max_size_cycle",
    stochastic_weight_avg: bool = False,
    terminate_on_nan: Optional[bool] = None,
) -> None:
    r"""
    Customize every aspect of training via flags.

    Args:

        accelerator: Supports passing different accelerator types ("cpu", "gpu", "tpu", "ipu", "auto")
            as well as custom accelerator instances.

            .. deprecated:: v1.5
                Passing training strategies (e.g., 'ddp') to ``accelerator`` has been deprecated in v1.5.0
                and will be removed in v1.7.0. Please use the ``strategy`` argument instead.

        accumulate_grad_batches: Accumulates grads every k batches or as set up in the dict.

        amp_backend: The mixed precision backend to use ("native" or "apex").

        amp_level: The optimization level to use (O1, O2, etc...). By default it will be set to "O2"
            if ``amp_backend`` is set to "apex".

        auto_lr_find: If set to True, will make trainer.tune() run a learning rate finder,
            trying to optimize initial learning for faster convergence. trainer.tune() method will
            set the suggested learning rate in self.lr or self.learning_rate in the LightningModule.
            To use a different key set a string instead of True with the key name.

        auto_scale_batch_size: If set to True, will `initially` run a batch size
            finder trying to find the largest batch size that fits into memory.
            The result will be stored in self.batch_size in the LightningModule.
            Additionally, can be set to either `power` that estimates the batch size through
            a power search or `binsearch` that estimates the batch size through a binary search.

        auto_select_gpus: If enabled and ``gpus`` is an integer, pick available
            gpus automatically. This is especially useful when
            GPUs are configured to be in "exclusive mode", such
            that only one process at a time can access them.

        benchmark: If true enables cudnn.benchmark.

        callbacks: Add a callback or list of callbacks.

        checkpoint_callback: If ``True``, enable checkpointing.

            .. deprecated:: v1.5
                ``checkpoint_callback`` has been deprecated in v1.5 and will be removed in v1.7.
                Please consider using ``enable_checkpointing`` instead.

        enable_checkpointing: If ``True``, enable checkpointing.
            It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
            :paramref:`~pytorch_lightning.trainer.trainer.Trainer.callbacks`.

        check_val_every_n_epoch: Check val every n train epochs.

        default_root_dir: Default path for logs and weights when no logger/ckpt_callback passed.
            Default: ``os.getcwd()``.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'

        detect_anomaly: Enable anomaly detection for the autograd engine.

        deterministic: If ``True``, sets whether PyTorch operations must use deterministic algorithms.
            Default: ``False``.

        devices: Will be mapped to either `gpus`, `tpu_cores`, `num_processes` or `ipus`,
            based on the accelerator type.

        fast_dev_run: Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
            of train, val and test to find any bugs (ie: a sort of unit test).

        flush_logs_every_n_steps: How often to flush logs to disk (defaults to every 100 steps).

            .. deprecated:: v1.5
                ``flush_logs_every_n_steps`` has been deprecated in v1.5 and will be removed in v1.7.
                Please configure flushing directly in the logger instead.

        gpus: Number of GPUs to train on (int) or which GPUs to train on (list or str) applied per node

        gradient_clip_val: The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
            gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.

        gradient_clip_algorithm: The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
            to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
            be set to ``"norm"``.

        limit_train_batches: How much of training dataset to check (float = fraction, int = num_batches).

        limit_val_batches: How much of validation dataset to check (float = fraction, int = num_batches).

        limit_test_batches: How much of test dataset to check (float = fraction, int = num_batches).

        limit_predict_batches: How much of prediction dataset to check (float = fraction, int = num_batches).

        logger: Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
            the default ``TensorBoardLogger``. ``False`` will disable logging. If multiple loggers are
            provided and the `save_dir` property of that logger is not set, local files (checkpoints,
            profiler traces, etc.) are saved in ``default_root_dir`` rather than in the ``log_dir`` of any
            of the individual loggers.

        log_gpu_memory: None, 'min_max', 'all'. Might slow performance.

            .. deprecated:: v1.5
                Deprecated in v1.5.0 and will be removed in v1.7.0
                Please use the ``DeviceStatsMonitor`` callback directly instead.

        log_every_n_steps: How often to log within steps (defaults to every 50 steps).

        prepare_data_per_node: If True, each LOCAL_RANK=0 will call prepare data.
            Otherwise only NODE_RANK=0, LOCAL_RANK=0 will prepare data

            .. deprecated:: v1.5
                Deprecated in v1.5.0 and will be removed in v1.7.0
                Please set ``prepare_data_per_node`` in LightningDataModule or LightningModule directly instead.

        process_position: Orders the progress bar when running multiple models on same machine.

            .. deprecated:: v1.5
                ``process_position`` has been deprecated in v1.5 and will be removed in v1.7.
                Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``process_position``
                directly to the Trainer's ``callbacks`` argument instead.

        progress_bar_refresh_rate: How often to refresh progress bar (in steps). Value ``0`` disables progress bar.
            Ignored when a custom progress bar is passed to :paramref:`~Trainer.callbacks`. Default: None, means
            a suitable value will be chosen based on the environment (terminal, Google COLAB, etc.).

            .. deprecated:: v1.5
                ``progress_bar_refresh_rate`` has been deprecated in v1.5 and will be removed in v1.7.
                Please pass :class:`~pytorch_lightning.callbacks.progress.TQDMProgressBar` with ``refresh_rate``
                directly to the Trainer's ``callbacks`` argument instead. To disable the progress bar,
                pass ``enable_progress_bar = False`` to the Trainer.

        enable_progress_bar: Whether to enable the progress bar by default.

        profiler: To profile individual steps during training and assist in identifying bottlenecks.

        overfit_batches: Overfit a fraction of training data (float) or a set number of batches (int).

        plugins: Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.

        precision: Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
            Can be used on CPU, GPU or TPUs.

        max_epochs: Stop training once this number of epochs is reached. Disabled by default (None).
            If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
            To enable infinite training, set ``max_epochs = -1``.

        min_epochs: Force training for at least these many epochs. Disabled by default (None).
            If both min_epochs and min_steps are not specified, defaults to ``min_epochs = 1``.

        max_steps: Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
            and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
            ``max_epochs`` to ``-1``.

        min_steps: Force training for at least this number of steps. Disabled by default (None).

        max_time: Stop training after this amount of time has passed. Disabled by default (None).
            The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
            :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
            :class:`datetime.timedelta`.

        num_nodes: Number of GPU nodes for distributed training.

        num_processes: Number of processes for distributed training with ``accelerator="cpu"``.

        num_sanity_val_steps: Sanity check runs n validation batches before starting the training routine.
            Set it to `-1` to run all batches in all validation dataloaders.

        reload_dataloaders_every_n_epochs: Set to a non-negative integer to reload dataloaders every n epochs.

        reload_dataloaders_every_epoch: Set to True to reload dataloaders every epoch.

            .. deprecated:: v1.4
                ``reload_dataloaders_every_epoch`` has been deprecated in v1.4 and will be removed in v1.6.
                Please use ``reload_dataloaders_every_n_epochs``.

        replace_sampler_ddp: Explicitly enables or disables sampler replacement. If not specified this
            will be toggled automatically when DDP is used. By default it will add ``shuffle=True`` for
            train sampler and ``shuffle=False`` for val/test sampler. If you want to customize it,
            you can set ``replace_sampler_ddp=False`` and add your own distributed sampler.

        resume_from_checkpoint: Path/URL of the checkpoint from which training is resumed. If there is
            no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
            training will start from the beginning of the next epoch.

            .. deprecated:: v1.5
                ``resume_from_checkpoint`` is deprecated in v1.5 and will be removed in v1.7.
                Please pass the path to ``Trainer.fit(..., ckpt_path=...)`` instead.

        strategy: Supports different training strategies with aliases
            as well as custom training type plugins.

        sync_batchnorm: Synchronize batch norm layers between process groups/whole world.

        terminate_on_nan: If set to True, will terminate training (by raising a `ValueError`) at the
            end of each training batch, if any of the parameters or the loss are NaN or +/-inf.

            .. deprecated:: v1.5
                Trainer argument ``terminate_on_nan`` was deprecated in v1.5 and will be removed in 1.7.
                Please use ``detect_anomaly`` instead.

        tpu_cores: How many TPU cores to train on (1 or 8) / Single TPU to train on [1]

        ipus: How many IPUs to train on.

        track_grad_norm: -1 no tracking. Otherwise tracks that p-norm. May be set to 'inf' infinity-norm. If using
            Automatic Mixed Precision (AMP), the gradients will be unscaled before logging them.

        val_check_interval: How often to check the validation set. Use float to check within a training epoch,
            use int to check every n steps (batches).

        enable_model_summary: Whether to enable model summarization by default.

        weights_summary: Prints a summary of the weights when training begins.

            .. deprecated:: v1.5
                ``weights_summary`` has been deprecated in v1.5 and will be removed in v1.7.
                To disable the summary, pass ``enable_model_summary = False`` to the Trainer.
                To customize the summary, pass :class:`~pytorch_lightning.callbacks.model_summary.ModelSummary`
                directly to the Trainer's ``callbacks`` argument.

        weights_save_path: Where to save weights if specified. Will override default_root_dir
            for checkpoints only. Use this if for whatever reason you need the checkpoints
            stored in a different place than the logs written in `default_root_dir`.
            Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/'
            Defaults to `default_root_dir`.

        move_metrics_to_cpu: Whether to force internal logged metrics to be moved to cpu.
            This can save some gpu memory, but can make training slower. Use with attention.

        multiple_trainloader_mode: How to loop over the datasets when there are multiple train loaders.
            In 'max_size_cycle' mode, the trainer ends one epoch when the largest dataset is traversed,
            and smaller datasets reload when running out of their data. In 'min_size' mode, all the datasets
            reload when reaching the minimum length of datasets.

        stochastic_weight_avg: Whether to use `Stochastic Weight Averaging (SWA)
            <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging/>`_.

            .. deprecated:: v1.5
                ``stochastic_weight_avg`` has been deprecated in v1.5 and will be removed in v1.7.
                Please pass :class:`~pytorch_lightning.callbacks.stochastic_weight_avg.StochasticWeightAveraging`
                directly to the Trainer's ``callbacks`` argument instead.

    Raises:
        TypeError: If ``terminate_on_nan`` is given but is not a bool, or if ``gradient_clip_val`` is
            given but is neither an int nor a float.
        MisconfigurationException: If ``gradient_clip_algorithm`` is not a supported algorithm, or if
            ``track_grad_norm`` is neither ``-1``, a positive number, nor ``'inf'``.
    """
    super().__init__()
    Trainer._log_api_event("init")
    self.state = TrainerState()

    # note: rebinds the local ``tpu_cores`` to the parsed value before it is handed to the connector
    gpu_ids, tpu_cores = self._parse_devices(gpus, auto_select_gpus, tpu_cores)

    # init connectors
    self._data_connector = DataConnector(self, multiple_trainloader_mode)

    self._accelerator_connector = AcceleratorConnector(
        num_processes,
        devices,
        tpu_cores,
        ipus,
        accelerator,
        strategy,
        gpus,
        gpu_ids,
        num_nodes,
        sync_batchnorm,
        benchmark,
        replace_sampler_ddp,
        deterministic,
        precision,
        amp_backend,
        amp_level,
        plugins,
    )
    self.logger_connector = LoggerConnector(self, log_gpu_memory)
    self._callback_connector = CallbackConnector(self)
    self.checkpoint_connector = CheckpointConnector(self, resume_from_checkpoint)
    self.signal_connector = SignalConnector(self)
    self.tuner = Tuner(self)

    # ``min_epochs`` falls back to 1 only when no other lower bound or time budget was requested;
    # ``max_epochs`` falls back to 1000 only when neither ``max_steps`` nor ``max_time`` limits the run,
    # otherwise to -1.
    fit_loop = FitLoop(
        min_epochs=(1 if (min_epochs is None and min_steps is None and max_time is None) else min_epochs),
        max_epochs=(
            max_epochs if max_epochs is not None else (1000 if (max_steps == -1 and max_time is None) else -1)
        ),
    )
    training_epoch_loop = TrainingEpochLoop(min_steps, max_steps)
    training_batch_loop = TrainingBatchLoop()
    training_validation_loop = EvaluationLoop()
    training_epoch_loop.connect(batch_loop=training_batch_loop, val_loop=training_validation_loop)
    fit_loop.connect(epoch_loop=training_epoch_loop)

    # default .fit() loop
    self.fit_loop = fit_loop

    # default .validate() loop
    self.validate_loop = EvaluationLoop()

    # default .test() loop
    self.test_loop = EvaluationLoop()

    # default .predict() loop
    self.predict_loop = PredictionLoop()

    # Needed because of LightningOptimizer
    self._lightning_optimizers = None

    # .validate() and .test() set this when they load a checkpoint
    self.validated_ckpt_path: Optional[str] = None
    self.tested_ckpt_path: Optional[str] = None
    self.predicted_ckpt_path: Optional[str] = None

    # todo: remove in v1.7
    self._weights_summary: Optional[str] = None

    # init callbacks
    # Declare attributes to be set in _callback_connector on_trainer_init
    self._callback_connector.on_trainer_init(
        callbacks,
        checkpoint_callback,
        enable_checkpointing,
        enable_progress_bar,
        progress_bar_refresh_rate,
        process_position,
        default_root_dir,
        weights_save_path,
        enable_model_summary,
        weights_summary,
        stochastic_weight_avg,
        max_time,
        accumulate_grad_batches,
    )

    # hook
    self.on_init_start()

    # init optimizer + lr scheduler related flags
    self.lr_schedulers = []
    self.optimizers = []
    self.optimizer_frequencies = []

    # init data flags
    self._data_connector.on_trainer_init(
        check_val_every_n_epoch,
        reload_dataloaders_every_n_epochs,
        reload_dataloaders_every_epoch,
        prepare_data_per_node,
    )

    # ``None`` means the deprecated flag was not passed at all, so no warning is emitted
    if terminate_on_nan is not None:
        rank_zero_deprecation(
            "Trainer argument `terminate_on_nan` was deprecated in v1.5 and will be removed in 1.7."
            " Please use `Trainer(detect_anomaly=True)` instead."
        )
        if not isinstance(terminate_on_nan, bool):
            raise TypeError(f"`terminate_on_nan` should be a bool, got {terminate_on_nan}.")

    # gradient clipping
    if gradient_clip_val is not None and not isinstance(gradient_clip_val, (int, float)):
        raise TypeError(f"`gradient_clip_val` should be an int or a float. Got {gradient_clip_val}.")

    if gradient_clip_algorithm is not None and not GradClipAlgorithmType.supported_type(
        gradient_clip_algorithm.lower()
    ):
        raise MisconfigurationException(
            f"`gradient_clip_algorithm` {gradient_clip_algorithm} is invalid. "
            f"Allowed algorithms: {GradClipAlgorithmType.supported_types()}."
        )

    # gradient norm tracking
    # accepted values: -1 (disabled), any number > 0, or the string "inf"
    if track_grad_norm != -1 and not (
        (isinstance(track_grad_norm, (int, float)) or track_grad_norm == "inf") and float(track_grad_norm) > 0
    ):
        raise MisconfigurationException(
            f"`track_grad_norm` must be a positive number or 'inf' (infinity norm). Got {track_grad_norm}."
        )

    self._terminate_on_nan = terminate_on_nan
    self.gradient_clip_val = gradient_clip_val
    # stored as the enum member when given, otherwise left as ``None``
    self.gradient_clip_algorithm = (
        GradClipAlgorithmType(gradient_clip_algorithm.lower())
        if gradient_clip_algorithm is not None
        else gradient_clip_algorithm
    )
    self.track_grad_norm: float = float(track_grad_norm)

    self._detect_anomaly: bool = detect_anomaly
    self._setup_on_init(num_sanity_val_steps)

    # configure tuner
    self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

    # configure profiler
    self.__init_profiler(profiler)

    # init logger flags
    self.logger_connector.on_trainer_init(logger, flush_logs_every_n_steps, log_every_n_steps, move_metrics_to_cpu)

    # init debugging flags
    self._init_debugging_flags(
        limit_train_batches,
        limit_val_batches,
        limit_test_batches,
        limit_predict_batches,
        val_check_interval,
        overfit_batches,
        fast_dev_run,
    )

    # Callback system
    self.on_init_end()
def _init_debugging_flags(
    self,
    limit_train_batches,
    limit_val_batches,
    limit_test_batches,
    limit_predict_batches,
    val_check_interval,
    overfit_batches,
    fast_dev_run,
):
    """Validate ``fast_dev_run`` and store the batch-limiting debug flags on the trainer.

    When ``fast_dev_run`` is enabled, every ``limit_*_batches`` value is replaced by the requested
    number of batches, the fit loop is capped to that many steps and a single epoch, the sanity
    check is disabled, validation runs every epoch, and any configured logger is swapped for a
    ``DummyLogger``. All limits are then passed through ``_determine_batch_limits`` and
    ``overfit_batches`` is applied last via ``determine_data_use_amount``.

    Raises:
        MisconfigurationException: If ``fast_dev_run`` is neither a bool nor an int, or is negative.
    """
    # ``bool`` is a subclass of ``int``, so one isinstance check covers both accepted types
    if not isinstance(fast_dev_run, int):
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be either a bool or an int >= 0"
        )
    if fast_dev_run < 0:
        raise MisconfigurationException(
            f"fast_dev_run={fast_dev_run} is not a valid configuration. It should be >= 0."
        )

    n_batches = int(fast_dev_run)
    # a value of 1 is stored as ``True`` (used while logging); anything else is kept as passed
    self.fast_dev_run = True if n_batches == 1 else fast_dev_run

    if n_batches:
        limit_train_batches = limit_val_batches = limit_test_batches = limit_predict_batches = n_batches
        val_check_interval = 1.0

        self.fit_loop.max_steps = n_batches
        self.num_sanity_val_steps = 0
        self.fit_loop.max_epochs = 1
        self.check_val_every_n_epoch = 1
        if self.logger is not None:
            self.logger = DummyLogger()
        else:
            self.logger = None

        rank_zero_info(
            "Running in fast_dev_run mode: will run a full train,"
            f" val, test and prediction loop using {n_batches} batch(es)."
        )

    # normalise each limit in the same order as before; the attribute name doubles as the label
    requested = (
        ("limit_train_batches", limit_train_batches),
        ("limit_val_batches", limit_val_batches),
        ("limit_test_batches", limit_test_batches),
        ("limit_predict_batches", limit_predict_batches),
        ("val_check_interval", val_check_interval),
        ("overfit_batches", overfit_batches),
    )
    for attr_name, raw_value in requested:
        setattr(self, attr_name, _determine_batch_limits(raw_value, attr_name))

    self.determine_data_use_amount(self.overfit_batches)
def determine_data_use_amount(self, overfit_batches: float) -> None:
    """Use less data for debugging purposes.

    When ``overfit_batches`` is greater than zero, it overwrites the train, val and test batch
    limits; otherwise nothing is changed.
    """
    if not overfit_batches > 0:
        return
    for limit_name in ("limit_train_batches", "limit_val_batches", "limit_test_batches"):
        setattr(self, limit_name, overfit_batches)
def _setup_on_init(self, num_sanity_val_steps: int) -> None:
    """Log device info and reset the trainer's run-time bookkeeping attributes.

    Args:
        num_sanity_val_steps: Number of sanity-check validation batches; ``-1`` is stored as infinity.
    """
    self._log_device_info()

    self.should_stop = False
    self.state = TrainerState()
    self.num_training_batches = float("inf")
    self.train_dataloader = None

    # -1 is the sentinel for "no limit"
    self.num_sanity_val_steps = float("inf") if num_sanity_val_steps == -1 else num_sanity_val_steps

    self.num_sanity_val_batches = []
    self.num_test_batches = []
    self.num_val_batches = []
    self.test_dataloaders = None
    self.val_dataloaders = None

    # when true, print evaluation results in .validate() and .test()
    self.verbose_evaluate = True

    self.num_predict_batches = []


def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any:
    r"""
    Run a trainer entry point (fit, validate, test, predict) with shared error handling, as all
    errors should funnel through these entry points.

    A ``KeyboardInterrupt`` is not re-raised: on the first interrupt the status is set to
    ``INTERRUPTED`` and the ``on_keyboard_interrupt`` / ``on_exception`` hooks run, after which the
    method falls through and returns ``None``. Any other ``BaseException`` sets the status to
    ``INTERRUPTED``, runs the exception hooks and is re-raised.

    Args:
        trainer_fn: one of (fit, validate, test, predict)
        *args: positional arguments to be passed to the `trainer_fn`
        **kwargs: keyword arguments to be passed to `trainer_fn`

    Returns:
        Whatever ``trainer_fn`` returns, or ``None`` after a handled ``KeyboardInterrupt``.
    """
    try:
        return trainer_fn(*args, **kwargs)
    # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    except KeyboardInterrupt as interrupt:
        rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
        # user could press Ctrl+c many times... only shutdown once
        if not self.interrupted:
            self.state.status = TrainerStatus.INTERRUPTED
            self.on_keyboard_interrupt()
            self.on_exception(interrupt)
    except BaseException as err:
        self.state.status = TrainerStatus.INTERRUPTED
        multi_process_run = distributed_available() and self.world_size > 1
        if multi_process_run:
            # try syncing remaining processes, kill otherwise
            self.training_type_plugin.reconciliate_processes(traceback.format_exc())
        self._on_exception()
        # reset bookkeeping
        self.state.stage = None
        self.on_exception(err)
        raise
def fit(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    train_dataloader=None,  # TODO: remove with 1.6
    ckpt_path: Optional[str] = None,
) -> None:
    r"""
    Runs the full optimization routine.

    Args:
        model: Model to fit.

        train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
            :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
            In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

        val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

        datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        train_dataloader: Deprecated alias of ``train_dataloaders``. When given, a deprecation
            warning is emitted and its value replaces ``train_dataloaders``.

        ckpt_path: Path/URL of the checkpoint from which training is resumed. If there is
            no checkpoint file at the path, an exception is raised. If resuming from mid-epoch checkpoint,
            training will start from the beginning of the next epoch.
    """
    legacy_arg_used = train_dataloader is not None
    if legacy_arg_used:
        rank_zero_deprecation(
            "`trainer.fit(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
            " Use `trainer.fit(train_dataloaders)` instead. HINT: added 's'"
        )
    # the deprecated argument wins whenever it was supplied
    effective_loaders = train_dataloader if legacy_arg_used else train_dataloaders
    impl_args = (model, effective_loaders, val_dataloaders, datamodule, ckpt_path)
    self._call_and_handle_interrupt(self._fit_impl, *impl_args)
def _fit_impl(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    ckpt_path: Optional[str] = None,
) -> None:
    """Mark the trainer as fitting, attach the data sources and run the model.

    A ``LightningDataModule`` passed as ``train_dataloaders`` is treated as the ``datamodule``.
    If ``ckpt_path`` is falsy, ``self.resume_from_checkpoint`` is used instead.

    Raises:
        MisconfigurationException: If a datamodule is combined with explicit train/val dataloaders.
    """
    Trainer._log_api_event("fit")

    self.state.fn = TrainerFn.FITTING
    self.state.status = TrainerStatus.RUNNING
    self.training = True

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(train_dataloaders, LightningDataModule):
        datamodule, train_dataloaders = train_dataloaders, None

    # If you supply a datamodule you can't supply train_dataloader or val_dataloaders
    explicit_loaders_given = train_dataloaders is not None or val_dataloaders is not None
    if explicit_loaders_given and datamodule is not None:
        raise MisconfigurationException(
            "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.fit(datamodule=...)`"
        )

    # links data to the trainer
    self._data_connector.attach_data(
        model,
        train_dataloaders=train_dataloaders,
        val_dataloaders=val_dataloaders,
        datamodule=datamodule,
    )

    # TODO: ckpt_path only in v1.7
    self._run(model, ckpt_path=ckpt_path or self.resume_from_checkpoint)

    assert self.state.stopped
    self.training = False
def validate(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[str] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
    val_dataloaders=None,  # TODO: remove with 1.6
) -> _EVALUATE_OUTPUT:
    r"""
    Perform one evaluation epoch over the validation set.

    Args:
        model: The model to validate.

        dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
            or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying validation samples.

        ckpt_path: Either ``best`` or path to the checkpoint you wish to validate.
            If ``None`` and the model instance was passed, use the current weights.
            Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
            if a checkpoint callback is configured.

        verbose: If True, prints the validation results.

        datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        val_dataloaders: Deprecated alias of ``dataloaders``. When given, a deprecation warning is
            emitted and its value replaces ``dataloaders``.

    Returns:
        List of dictionaries with metrics logged during the validation phase, e.g., in model- or callback hooks
        like :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_step`,
        :meth:`~pytorch_lightning.core.lightning.LightningModule.validation_epoch_end`, etc.
        The length of the list corresponds to the number of validation dataloaders used.
    """
    legacy_arg_used = val_dataloaders is not None
    if legacy_arg_used:
        rank_zero_deprecation(
            "`trainer.validate(val_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
            " Use `trainer.validate(dataloaders)` instead."
        )
    # the deprecated argument wins whenever it was supplied
    effective_loaders = val_dataloaders if legacy_arg_used else dataloaders
    impl_args = (model, effective_loaders, ckpt_path, verbose, datamodule)
    return self._call_and_handle_interrupt(self._validate_impl, *impl_args)
def _validate_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[str] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    """Mark the trainer as validating, attach the data sources, resolve the checkpoint and run.

    Args:
        model: The model to validate. When ``None``, ``self.lightning_module`` is used.
        dataloaders: Validation dataloaders, or a ``LightningDataModule`` (which is then treated as
            ``datamodule``).
        ckpt_path: Checkpoint selector forwarded to ``__set_ckpt_path``.
        verbose: Stored on ``self.verbose_evaluate``.
        datamodule: Datamodule providing the validation data.

    Returns:
        The value returned by ``self._run``.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are given, or if no
            model was passed and the trainer holds none.
    """
    # --------------------
    # SETUP HOOK
    # --------------------
    Trainer._log_api_event("validate")
    self.verbose_evaluate = verbose

    self.state.fn = TrainerFn.VALIDATING
    self.state.status = TrainerStatus.RUNNING
    self.validating = True

    # if a datamodule comes in as the second arg, then fix it for the user
    if isinstance(dataloaders, LightningDataModule):
        datamodule = dataloaders
        dataloaders = None
    # If you supply a datamodule you can't supply val_dataloaders.
    # Identity check (not truthiness), matching the equivalent guard in `_fit_impl`.
    if dataloaders is not None and datamodule is not None:
        raise MisconfigurationException("You cannot pass both `trainer.validate(dataloaders=..., datamodule=...)`")

    model_provided = model is not None
    # Compare against None rather than relying on truthiness: a module that defines `__len__`
    # (or `__bool__`) may evaluate as False and must not be silently replaced.
    if not model_provided:
        model = self.lightning_module
    if model is None:
        raise MisconfigurationException(
            "`model` must be provided to `trainer.validate()` when it hasn't been passed in a previous run"
        )

    # links data to the trainer
    self._data_connector.attach_data(model, val_dataloaders=dataloaders, datamodule=datamodule)

    self.validated_ckpt_path = self.__set_ckpt_path(
        ckpt_path, model_provided=model_provided, model_connected=self.lightning_module is not None
    )

    # run validate
    results = self._run(model, ckpt_path=self.validated_ckpt_path)

    assert self.state.stopped
    self.validating = False

    return results
def test(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[str] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
    test_dataloaders=None,  # TODO: remove with 1.6
) -> _EVALUATE_OUTPUT:
    r"""Run a single evaluation epoch on the test set.

    Testing is kept apart from fitting so the test set is only touched when explicitly requested.

    Args:
        model: The model to test.

        dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
            or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying test samples.

        ckpt_path: Either ``best`` or path to the checkpoint you wish to test.
            If ``None`` and the model instance was passed, use the current weights.
            Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
            if a checkpoint callback is configured.

        verbose: If True, prints the test results.

        datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        test_dataloaders: Deprecated spelling of ``dataloaders``. When it is not ``None`` a deprecation
            message is emitted and its value replaces ``dataloaders``.

    Returns:
        List of dictionaries with metrics logged during the test phase, e.g., in model- or callback hooks
        like :meth:`~pytorch_lightning.core.lightning.LightningModule.test_step`,
        :meth:`~pytorch_lightning.core.lightning.LightningModule.test_epoch_end`, etc.
        The length of the list corresponds to the number of test dataloaders used.
    """
    if test_dataloaders is not None:
        rank_zero_deprecation(
            "`trainer.test(test_dataloaders)` is deprecated in v1.4 and will be removed in v1.6."
            " Use `trainer.test(dataloaders)` instead."
        )
        # the deprecated keyword takes over the role of `dataloaders`
        dataloaders = test_dataloaders

    # the actual work happens in `_test_impl`, executed through the interrupt-handling wrapper
    impl_args = (model, dataloaders, ckpt_path, verbose, datamodule)
    return self._call_and_handle_interrupt(self._test_impl, *impl_args)
def _test_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    ckpt_path: Optional[str] = None,
    verbose: bool = True,
    datamodule: Optional[LightningDataModule] = None,
) -> _EVALUATE_OUTPUT:
    """Implementation behind ``test``: set the trainer state, attach the data and run.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are given, or if no model
            is passed and the trainer has no ``lightning_module``.
    """
    # --------------------
    # SETUP HOOK
    # --------------------
    Trainer._log_api_event("test")
    self.verbose_evaluate = verbose

    self.state.fn = TrainerFn.TESTING
    self.state.status = TrainerStatus.RUNNING
    self.testing = True

    # a datamodule handed over in the `dataloaders` slot is moved to where it belongs
    if isinstance(dataloaders, LightningDataModule):
        datamodule, dataloaders = dataloaders, None

    # explicit dataloaders and a datamodule are mutually exclusive
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.test(dataloaders=..., datamodule=...)`")

    user_gave_model = model is not None
    model = model or self.lightning_module
    if model is None:
        raise MisconfigurationException(
            "`model` must be provided to `trainer.test()` when it hasn't been passed in a previous run"
        )

    # links data to the trainer
    self._data_connector.attach_data(model, test_dataloaders=dataloaders, datamodule=datamodule)

    has_connected_model = self.lightning_module is not None
    self.tested_ckpt_path = self.__set_ckpt_path(
        ckpt_path, model_provided=user_gave_model, model_connected=has_connected_model
    )

    # run test
    outputs = self._run(model, ckpt_path=self.tested_ckpt_path)

    assert self.state.stopped
    self.testing = False

    return outputs
def predict(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
    r"""Run inference on your data.

    This will call the model forward function to compute predictions. Useful to perform distributed
    and batched predictions. Logging is disabled in the predict hooks.

    Args:
        model: The model to predict with.

        dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them,
            or a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying prediction samples.

        datamodule: The datamodule with a predict_dataloader method that returns one or more dataloaders.

        return_predictions: Whether to return predictions.
            ``True`` by default except when an accelerator that spawns processes is used (not supported).

        ckpt_path: Either ``best`` or path to the checkpoint you wish to predict.
            If ``None`` and the model instance was passed, use the current weights.
            Otherwise, the best model checkpoint from the previous ``trainer.fit`` call will be loaded
            if a checkpoint callback is configured.

    Returns:
        Returns a list of dictionaries, one for each provided dataloader containing their respective predictions.
    """
    # the actual work happens in `_predict_impl`, executed through the interrupt-handling wrapper
    impl_args = (model, dataloaders, datamodule, return_predictions, ckpt_path)
    return self._call_and_handle_interrupt(self._predict_impl, *impl_args)
def _predict_impl(
    self,
    model: Optional["pl.LightningModule"] = None,
    dataloaders: Optional[Union[EVAL_DATALOADERS, LightningDataModule]] = None,
    datamodule: Optional[LightningDataModule] = None,
    return_predictions: Optional[bool] = None,
    ckpt_path: Optional[str] = None,
) -> Optional[_PREDICT_OUTPUT]:
    """Implementation behind ``predict``: set the trainer state, attach the data and run.

    Raises:
        MisconfigurationException: If both ``dataloaders`` and ``datamodule`` are given, or if no model
            is passed and the trainer has no ``lightning_module``.
    """
    # --------------------
    # SETUP HOOK
    # --------------------
    Trainer._log_api_event("predict")

    self.state.fn = TrainerFn.PREDICTING
    self.state.status = TrainerStatus.RUNNING
    self.predicting = True

    self.predict_loop.return_predictions = return_predictions

    # a datamodule handed over in the `dataloaders` slot is moved to where it belongs
    if isinstance(dataloaders, LightningDataModule):
        datamodule, dataloaders = dataloaders, None

    # explicit dataloaders and a datamodule are mutually exclusive
    if dataloaders is not None and datamodule:
        raise MisconfigurationException("You cannot pass both `trainer.predict(dataloaders=..., datamodule=...)`")

    user_gave_model = model is not None
    model = model or self.lightning_module
    if model is None:
        raise MisconfigurationException(
            "`model` must be provided to `trainer.predict()` when it hasn't been passed in a previous run"
        )

    # links data to the trainer
    self._data_connector.attach_data(model, predict_dataloaders=dataloaders, datamodule=datamodule)

    has_connected_model = self.lightning_module is not None
    self.predicted_ckpt_path = self.__set_ckpt_path(
        ckpt_path, model_provided=user_gave_model, model_connected=has_connected_model
    )

    outputs = self._run(model, ckpt_path=self.predicted_ckpt_path)

    assert self.state.stopped
    self.predicting = False

    return outputs
def tune(
    self,
    model: "pl.LightningModule",
    train_dataloaders: Optional[Union[TRAIN_DATALOADERS, LightningDataModule]] = None,
    val_dataloaders: Optional[EVAL_DATALOADERS] = None,
    datamodule: Optional[LightningDataModule] = None,
    scale_batch_size_kwargs: Optional[Dict[str, Any]] = None,
    lr_find_kwargs: Optional[Dict[str, Any]] = None,
    train_dataloader=None,  # TODO: remove with 1.6
) -> Dict[str, Optional[Union[int, _LRFinder]]]:
    r"""Run the hyperparameter tuning routines that precede training.

    Args:
        model: Model to tune.

        train_dataloaders: A collection of :class:`torch.utils.data.DataLoader` or a
            :class:`~pytorch_lightning.core.datamodule.LightningDataModule` specifying training samples.
            In the case of multiple dataloaders, please see this :ref:`page <multiple-training-dataloaders>`.

        val_dataloaders: A :class:`torch.utils.data.DataLoader` or a sequence of them specifying validation samples.

        datamodule: An instance of :class:`~pytorch_lightning.core.datamodule.LightningDataModule`.

        scale_batch_size_kwargs: Arguments for :func:`~pytorch_lightning.tuner.batch_size_scaling.scale_batch_size`

        lr_find_kwargs: Arguments for :func:`~pytorch_lightning.tuner.lr_finder.lr_find`

        train_dataloader: Deprecated spelling of ``train_dataloaders``. When it is not ``None`` a deprecation
            message is emitted and its value replaces ``train_dataloaders``.

    Returns:
        Whatever ``self.tuner._tune`` returns for the given model and keyword arguments.

    Raises:
        MisconfigurationException: If a datamodule is combined with train or validation dataloaders.
    """
    Trainer._log_api_event("tune")

    self.state.fn = TrainerFn.TUNING
    self.state.status = TrainerStatus.RUNNING
    self.tuning = True

    if train_dataloader is not None:
        rank_zero_deprecation(
            "`trainer.tune(train_dataloader)` is deprecated in v1.4 and will be removed in v1.6."
            " Use `trainer.tune(train_dataloaders)` instead. HINT: added 's'"
        )
        train_dataloaders = train_dataloader

    # a datamodule handed over in the `train_dataloaders` slot is moved to where it belongs
    if isinstance(train_dataloaders, LightningDataModule):
        datamodule, train_dataloaders = train_dataloaders, None

    # a datamodule excludes passing train or validation dataloaders directly
    if datamodule is not None and (train_dataloaders is not None or val_dataloaders is not None):
        raise MisconfigurationException(
            "You cannot pass `train_dataloader` or `val_dataloaders` to `trainer.tune(datamodule=...)`"
        )

    # links data to the trainer
    self._data_connector.attach_data(
        model, train_dataloaders=train_dataloaders, val_dataloaders=val_dataloaders, datamodule=datamodule
    )

    tuning_result = self.tuner._tune(
        model, scale_batch_size_kwargs=scale_batch_size_kwargs, lr_find_kwargs=lr_find_kwargs
    )

    assert self.state.stopped
    self.tuning = False

    return tuning_result
def _restore_modules_and_callbacks(self, checkpoint_path: Optional[_PATH] = None) -> None:
    """Restore model and datamodule state from a checkpoint; callbacks are restored only when fitting."""
    # restore modules after setup
    self.checkpoint_connector.resume_start(checkpoint_path)
    self.checkpoint_connector.restore_model()
    self.checkpoint_connector.restore_datamodule()
    if self.state.fn == TrainerFn.FITTING:
        # restore callback states
        self.checkpoint_connector.restore_callbacks()


def _run(
    self, model: "pl.LightningModule", ckpt_path: Optional[str] = None
) -> Optional[Union[_EVALUATE_OUTPUT, _PREDICT_OUTPUT]]:
    """Shared driver of every trainer entry point: set up, dispatch to the plugin, then tear down.

    Args:
        model: The module to run; it is connected to the training type plugin.
        ckpt_path: Checkpoint path forwarded to ``_restore_modules_and_callbacks``.

    Returns:
        The ``results`` attribute of the training type plugin.
    """
    # clean hparams
    if hasattr(model, "hparams"):
        parsing.clean_namespace(model.hparams)

    verify_loop_configurations(self, model)

    # attach model log function to callback
    self._callback_connector.attach_model_logging_functions(model)

    # attach model to the training type plugin
    self.training_type_plugin.connect(model)

    # hook
    self._data_connector.prepare_data()
    self._callback_connector._attach_model_callbacks()

    # ----------------------------
    # SET UP TRAINING
    # ----------------------------
    self.call_hook("on_before_accelerator_backend_setup")
    self.accelerator.setup_environment()
    self._call_setup_hook()  # allow user to setup lightning_module in accelerator environment

    # check if we should delay restoring checkpoint till later
    if not self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
        self._restore_modules_and_callbacks(ckpt_path)

    self._call_configure_sharded_model()  # allow user to setup in model sharded environment
    self.accelerator.setup(self)

    # ----------------------------
    # INSPECT THE CORE LOOPS
    # ----------------------------
    # NOTE: the f-string below is an expression statement whose value is discarded; it exists so that
    # editors can resolve the referenced attributes ("Go to Definition").
    fr"""
         Lightning internal flow looks like this:
    {Trainer.fit} or {Trainer.test} or {Trainer.predict}  ||
                            |                             ||
                    create accelerator                    ||
                            |                             ||
                     {self._dispatch}                     ||
                            |                             ||  LIGHTNING
     {self.training_type_plugin.start_training}           ||
  or {self.training_type_plugin.start_evaluating}         ||
  or {self.training_type_plugin.start_predicting}         ||  FLOW
                            |                             ||
                     {self.run_stage}                     ||
                            |                             ||  DIRECTION
                    {self._run_train}                     ||
                 or {self._run_evaluate}                  ||
                 or {self._run_predict}                   ||
                            |                             ||
                         results                          \/
    This is used to guide readers to the core loops: train, test, predict.
    {self._run_predict} is the simplest to understand, use `Go to Definition` to read it :)
    Search for `start_training` or `start_evaluating` or `start_predicting` in
    `pytorch_lightning/plugins/training_type_plugin` to find accelerator dispatch functions.
    """

    # ----------------------------
    # TRAIN
    # ----------------------------

    # reset logger connector
    self.logger_connector.reset_results()
    self.logger_connector.reset_metrics()

    # hook
    if self.state.fn == TrainerFn.FITTING:
        self.call_hook("on_fit_start")

    # plugin will setup fitting (e.g. ddp will launch child processes)
    self._pre_dispatch()

    if self.training_type_plugin.restore_checkpoint_after_pre_dispatch:
        self._restore_modules_and_callbacks(ckpt_path)

    # restore optimizers, etc.
    self.checkpoint_connector.restore_training_state()

    self.checkpoint_connector.resume_end()

    # dispatch `start_training` or `start_evaluating` or `start_predicting`
    self._dispatch()

    # plugin will finalized fitting (e.g. ddp_spawn will load trained model)
    self._post_dispatch()

    # ----------------------------
    # POST-Training CLEAN UP
    # ----------------------------
    # hook
    if self.state.fn == TrainerFn.FITTING:
        self.call_hook("on_fit_end")

    # teardown if necessary (similar calls for spawn plugins are excluded as they have
    # been included at the end of `new_process` functions)
    if not isinstance(self.training_type_plugin, DDPSpawnPlugin):
        self._call_teardown_hook()

    if self.state.status != TrainerStatus.INTERRUPTED:
        self.state.status = TrainerStatus.FINISHED
    self.state.stage = None

    return self.training_type_plugin.results


def _pre_dispatch(self):
    """Run the accelerator's pre-dispatch step, then log the hyperparameters."""
    self.accelerator.pre_dispatch(self)
    self._log_hyperparams()


def _log_hyperparams(self) -> None:
    """Send the initial hyperparameters of the module and/or datamodule to the logger.

    Does nothing when no logger is set. When both the LightningModule and the datamodule log
    hyperparameters, the two dictionaries are merged after checking that shared keys agree.

    Raises:
        MisconfigurationException: If a key present in both hparams dictionaries has different values.
    """
    # log hyper-parameters
    hparams_initial = None

    if self.logger is not None:
        # save exp to get started (this is where the first experiment logs are written)
        datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

        if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
            datamodule_hparams = self.datamodule.hparams_initial
            lightning_hparams = self.lightning_module.hparams_initial
            inconsistent_keys = []
            for key in lightning_hparams.keys() & datamodule_hparams.keys():
                lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
                if type(lm_val) != type(dm_val):
                    inconsistent_keys.append(key)
                elif isinstance(lm_val, torch.Tensor):
                    # Tensors are compared by identity only. Falling through to `!=` would take the
                    # truth value of an element-wise result, which raises for multi-element tensors.
                    if lm_val is not dm_val:
                        inconsistent_keys.append(key)
                elif lm_val != dm_val:
                    inconsistent_keys.append(key)
            if inconsistent_keys:
                raise MisconfigurationException(
                    f"Error while merging hparams: the keys {inconsistent_keys} are present "
                    "in both the LightningModule's and LightningDataModule's hparams "
                    "but have different values."
                )
            hparams_initial = {**lightning_hparams, **datamodule_hparams}
        elif self.lightning_module._log_hyperparams:
            hparams_initial = self.lightning_module.hparams_initial
        elif datamodule_log_hyperparams:
            hparams_initial = self.datamodule.hparams_initial

        if hparams_initial is not None:
            self.logger.log_hyperparams(hparams_initial)
        self.logger.log_graph(self.lightning_module)
        self.logger.save()


def _post_dispatch(self):
    """Run the accelerator's post-dispatch step followed by the internal teardowns."""
    self.accelerator.post_dispatch(self)
    # these `teardown` calls are here instead of in `_call_teardown_hook` since they are internal teardowns
    # which need to happen before.
    self.accelerator.teardown()
    self._data_connector.teardown()
    self._active_loop.teardown()
    self.logger_connector.teardown()


def _dispatch(self):
    """Hand control to the training type plugin for evaluating, predicting or training."""
    if self.evaluating:
        self.training_type_plugin.start_evaluating(self)
    elif self.predicting:
        self.training_type_plugin.start_predicting(self)
    else:
        self.training_type_plugin.start_training(self)


def run_stage(self):
    """Set up the profiler and run the evaluate, predict or train routine for the current stage."""
    self.accelerator.dispatch(self)
    self.__setup_profiler()

    if self.evaluating:
        return self._run_evaluate()
    if self.predicting:
        return self._run_predict()
    return self._run_train()


def _pre_training_routine(self):
    """Synchronize, register signal handlers and fire the pretrain-routine hooks."""
    # wait for all to join if on distributed
    self.training_type_plugin.barrier("setup_training")

    # register signals
    self.signal_connector.register_signal_handlers()

    # --------------------------
    # Pre-train
    # --------------------------
    self.call_hook("on_pretrain_routine_start")

    self.call_hook("on_pretrain_routine_end")


def _run_train(self) -> None:
    """Run the sanity check and then the fit loop with gradients enabled."""
    self._pre_training_routine()

    if not self.is_global_zero and self.progress_bar_callback is not None:
        self.progress_bar_callback.disable()

    self._run_sanity_check(self.lightning_module)

    # enable train mode
    self.model.train()
    torch.set_grad_enabled(True)

    self.fit_loop.trainer = self
    with torch.autograd.set_detect_anomaly(self._detect_anomaly):
        self.fit_loop.run()


def _run_evaluate(self) -> _EVALUATE_OUTPUT:
    """Run the evaluation loop without gradients and return its results with tensors turned into Python numbers."""
    if not self.is_global_zero and self.progress_bar_callback is not None:
        self.progress_bar_callback.disable()

    assert self.evaluating

    # reload dataloaders
    self._evaluation_loop._reload_evaluation_dataloaders()

    # reset trainer on this loop and all child loops in case user connected a custom loop
    self._evaluation_loop.trainer = self

    with self.profiler.profile(f"run_{self.state.stage}_evaluation"), torch.no_grad():
        eval_loop_results = self._evaluation_loop.run()

    # remove the tensors from the eval results
    for result in eval_loop_results:
        if isinstance(result, dict):
            # only values of existing keys are replaced, so iterating `items()` stays valid
            for k, v in result.items():
                if isinstance(v, torch.Tensor):
                    result[k] = v.cpu().item()

    return eval_loop_results


def _run_predict(self) -> Optional[_PREDICT_OUTPUT]:
    """Reset the predict dataloader and run the prediction loop without gradients."""
    self.reset_predict_dataloader(self.lightning_module)
    # reset trainer on this loop and all child loops in case user connected a custom loop
    self.predict_loop.trainer = self
    with torch.no_grad():
        return self.predict_loop.run()


def _run_sanity_check(self, ref_model):
    """Run a short validation pass before training when validation is defined and enabled.

    The running stage is saved before the check and restored afterwards.
    """
    using_val_step = self._data_connector._val_dataloader_source.is_defined() and is_overridden(
        "validation_step", ref_model
    )
    should_sanity_check = using_val_step and self.num_sanity_val_steps > 0 and self.limit_val_batches > 0

    # run tiny validation (if validation defined)
    # to make sure program won't crash during val
    if should_sanity_check:
        stage = self.state.stage
        self.sanity_checking = True

        # reset logger connector
        self.logger_connector.reset_results()
        self.logger_connector.reset_metrics()

        self.call_hook("on_sanity_check_start")

        # reload dataloaders
        self._evaluation_loop._reload_evaluation_dataloaders()

        # run eval step
        with torch.no_grad():
            self._evaluation_loop.run()

        self.call_hook("on_sanity_check_end")

        # reset logger connector
        self.logger_connector.reset_results()
        self.logger_connector.reset_metrics()

        # reset the seed to what it was before sanity check
        # prevents sanity check to affect random sampling in training
        reset_seed()

        # restore the previous stage when the sanity check is finished
        self.state.stage = stage


def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
    """Resolve the checkpoint path to load for an evaluate/predict call.

    Args:
        ckpt_path: ``None``, ``"best"`` or an explicit checkpoint path.
        model_provided: Whether the caller passed a model to the entry point.
        model_connected: Whether the trainer already has a ``lightning_module``.

    Returns:
        ``None`` when a model was provided without a checkpoint path (no weights are loaded),
        otherwise the resolved checkpoint path.

    Raises:
        MisconfigurationException: If ``"best"`` is requested but no checkpoint callback or no best model
            path is available, or if the resolved path is empty.
    """
    if model_provided and ckpt_path is None:
        # use passed model to function without loading weights
        return

    fn = self.state.fn.value

    if model_connected and ckpt_path is None:
        rank_zero_warn(
            f"`.{fn}(ckpt_path=None)` was called without a model."
            " The best model of the previous `fit` call will be used."
            f" You can pass `{fn}(ckpt_path='best')` to use and best model"
            " checkpoint and avoid this warning or"
            " `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
        )
        ckpt_path = "best"

    if ckpt_path == "best":
        # if user requests the best checkpoint but we don't have it, error
        if not self.checkpoint_callback:
            raise MisconfigurationException(
                f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured.'
            )

        if not self.checkpoint_callback.best_model_path:
            if self.fast_dev_run:
                raise MisconfigurationException(
                    f"You cannot execute `.{fn}()` with `fast_dev_run=True` unless you do"
                    f" `.{fn}(ckpt_path=PATH)` as no checkpoint path was generated during fitting."
                )
            raise MisconfigurationException(
                f'`.{fn}(ckpt_path="best")` is set but `ModelCheckpoint` is not configured to save the best model.'
            )
        # load best weights
        ckpt_path = self.checkpoint_callback.best_model_path

    if not ckpt_path:
        raise MisconfigurationException(
            f"`.{fn}()` found no path for the best weights: {ckpt_path!r}. Please"
            f" specify a path for a checkpoint `.{fn}(ckpt_path=PATH)`"
        )
    return ckpt_path


def _call_setup_hook(self) -> None:
    """Call ``setup`` on the datamodule (if any) and the ``setup`` hook, fenced by plugin barriers."""
    fn = self.state.fn._setup_fn

    self.training_type_plugin.barrier("pre_setup")

    if self.datamodule is not None:
        self.datamodule.setup(stage=fn)
    self.call_hook("setup", stage=fn)

    self.training_type_plugin.barrier("post_setup")


def _call_configure_sharded_model(self) -> None:
    """Materialize the module and call the sharded-model hooks inside the accelerator's sharded context."""
    with self.accelerator.model_sharded_context():
        materialize_module(self.lightning_module)
        self.call_hook("configure_sharded_model")
        self.call_hook("on_configure_sharded_model")


def _call_teardown_hook(self) -> None:
    """Call the teardown hooks, clear per-run module attributes, finalize the logger and describe the profiler."""
    fn = self.state.fn._setup_fn

    if self.datamodule is not None:
        self.datamodule.teardown(stage=fn)

    self.call_hook("teardown", stage=fn)

    self.lightning_module._current_fx_name = None
    self.lightning_module._current_dataloader_idx = None
    # these could have become stale if metrics are defined in `setup`
    self.lightning_module._metric_attributes = None

    # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
    # It might be related to xla tensors blocked when moving the cpu kill loggers.
    if self.logger is not None:
        self.logger.finalize("success")

    # summarize profile results
    self.profiler.describe()


def call_hook(
    self, hook_name: str, *args: Any, pl_module: Optional["pl.LightningModule"] = None, **kwargs: Any
) -> Any:
    """Call ``hook_name`` on the trainer, the LightningModule and, selectively, the accelerator and plugin.

    Args:
        hook_name: Name of the hook to look up on each object.
        *args: Positional arguments forwarded to every hook that is called.
        pl_module: Module used only when the trainer has no (truthy) ``lightning_module``.
        **kwargs: Keyword arguments forwarded to every hook that is called.

    Returns:
        The first non-``None`` output, checked in the order LightningModule, accelerator, training type
        plugin; ``None`` if none of them returned a value.
    """
    pl_module = self.lightning_module or pl_module
    if pl_module:
        prev_fx_name = pl_module._current_fx_name
        pl_module._current_fx_name = hook_name

    # always profile hooks
    with self.profiler.profile(hook_name):

        # first call trainer hook
        callback_fx = getattr(self, hook_name, None)
        if callable(callback_fx):
            callback_fx(*args, **kwargs)

        # next call hook in lightningModule
        output = None
        model_fx = getattr(pl_module, hook_name, None)
        if callable(model_fx):
            output = model_fx(*args, **kwargs)

        # *Bad code alert*
        # The `Accelerator` mostly calls the `TrainingTypePlugin` but some of those calls are deprecated.
        # The following logic selectively chooses which hooks are called on each object.
        # In the case of `setup` and `teardown`, the hooks on the `LightningModule` should not call the hooks of the
        # same name in these objects as they are meant to be managed outside of the `LightningModule` lifecycle.
        # All of this should be fixed by #8506

        # call the accelerator hook
        if hook_name in ("on_train_start",) and hasattr(self.accelerator, hook_name):
            accelerator_hook = getattr(self.accelerator, hook_name)
            accelerator_output = accelerator_hook(*args, **kwargs)
            # Rely on the accelerator output if lightningModule hook returns nothing
            # Required for cases such as DataParallel where we reduce the output for the user
            # todo: move this data parallel logic into the data parallel plugin
            output = accelerator_output if output is None else output

        # call the ttp hook
        if hook_name not in ("setup", "teardown", "on_train_start") and hasattr(
            self.training_type_plugin, hook_name
        ):
            ttp_hook = getattr(self.training_type_plugin, hook_name)
            ttp_output = ttp_hook(*args, **kwargs)
            output = ttp_output if output is None else output

    if pl_module:
        # restore current_fx when nested context
        pl_module._current_fx_name = prev_fx_name

    return output


@staticmethod
def _parse_devices(
    gpus: Optional[Union[List[int], str, int]],
    auto_select_gpus: bool,
    tpu_cores: Optional[Union[List[int], str, int]],
) -> Tuple[Optional[List[int]], Optional[Union[List[int], int]]]:
    """Parse the ``gpus`` and ``tpu_cores`` arguments and return ``(gpu_ids, tpu_cores)``."""
    if auto_select_gpus and isinstance(gpus, int):
        gpus = pick_multiple_gpus(gpus)

    # TODO (@seannaren, @kaushikb11): Include IPU parsing logic here
    gpu_ids = device_parser.parse_gpu_ids(gpus)
    tpu_cores = device_parser.parse_tpu_cores(tpu_cores)
    return gpu_ids, tpu_cores


@staticmethod
def _log_api_event(event: str) -> None:
    """Record a one-time API usage event named ``lightning.trainer.<event>`` through torch."""
    torch._C._log_api_usage_once("lightning.trainer." + event)


def __init_profiler(self, profiler: Optional[Union[BaseProfiler, str]]) -> None:
    """Set ``self.profiler`` from a profiler instance or one of the known profiler names.

    Raises:
        MisconfigurationException: If ``profiler`` is a string that is not a known profiler name.
    """
    if isinstance(profiler, str):
        PROFILERS = {
            "simple": SimpleProfiler,
            "advanced": AdvancedProfiler,
            "pytorch": PyTorchProfiler,
            "xla": XLAProfiler,
        }
        profiler = profiler.lower()
        if profiler not in PROFILERS:
            raise MisconfigurationException(
                "When passing string value for the `profiler` parameter of `Trainer`,"
                f" it can only be one of {list(PROFILERS.keys())}"
            )
        profiler_class = PROFILERS[profiler]
        profiler = profiler_class()
    # fall back to the pass-through profiler when nothing (truthy) was given
    self.profiler: BaseProfiler = profiler or PassThroughProfiler()


def __setup_profiler(self) -> None:
    """Give the profiler a weak proxy of the module and call its ``setup``."""
    local_rank = self.local_rank if self.world_size > 1 else None
    self.profiler._lightning_module = proxy(self.lightning_module)
    self.profiler.setup(stage=self.state.fn._setup_fn, local_rank=local_rank, log_dir=self.log_dir)


def _log_device_info(self) -> None:
    """Report GPU/TPU/IPU availability and usage, warning about available but unused devices."""
    rank_zero_info(f"GPU available: {torch.cuda.is_available()}, used: {self._device_type == DeviceType.GPU}")

    num_tpu_cores = self.tpu_cores if self.tpu_cores is not None and self._device_type == DeviceType.TPU else 0
    rank_zero_info(f"TPU available: {_TPU_AVAILABLE}, using: {num_tpu_cores} TPU cores")

    num_ipus = self.ipus if self.ipus is not None else 0
    rank_zero_info(f"IPU available: {_IPU_AVAILABLE}, using: {num_ipus} IPUs")

    if torch.cuda.is_available() and self._device_type != DeviceType.GPU:
        rank_zero_warn(
            "GPU available but not used. Set the gpus flag in your trainer `Trainer(gpus=1)` or script `--gpus=1`."
        )

    if _TPU_AVAILABLE and self._device_type != DeviceType.TPU:
        rank_zero_warn(
            "TPU available but not used. Set the `tpu_cores` flag in your trainer"
            " `Trainer(tpu_cores=8)` or script `--tpu_cores=8`."
        )

    if _IPU_AVAILABLE and self._device_type != DeviceType.IPU and not isinstance(self.accelerator, IPUAccelerator):
        rank_zero_warn(
            "IPU available but not used. Set the `ipus` flag in your trainer"
            " `Trainer(ipus=8)` or script `--ipus=8`."
        )


def _on_exception(self):
    """Save an auto-save checkpoint under the default root dir when fault-tolerant training is enabled."""
    if not _fault_tolerant_training():
        return
    # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
    file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
    self.save_checkpoint(file_path)


# ----------------------------
# Accelerator properties
# ----------------------------


@property
def accelerator(self) -> Accelerator:
    """The accelerator held by the accelerator connector."""
    return self._accelerator_connector.accelerator


@property
def training_type_plugin(self) -> TrainingTypePlugin:
    """The training type plugin of the accelerator."""
    return self.accelerator.training_type_plugin


@property
def precision_plugin(self) -> PrecisionPlugin:
    """The precision plugin of the accelerator."""
    return self.accelerator.precision_plugin


@property
def global_rank(self) -> int:
    """The global rank reported by the training type plugin."""
    return self.training_type_plugin.global_rank


@property
def local_rank(self) -> int:
    """The plugin's local rank, or ``0`` when the plugin does not define one."""
    # some training types define a local rank
    return getattr(self.training_type_plugin, "local_rank", 0)


@property
def node_rank(self) -> int:
    """The plugin's node rank, or ``0`` when the plugin does not define one."""
    # some training types define a node rank
    return getattr(self.training_type_plugin, "node_rank", 0)


@property
def world_size(self) -> int:
    """The plugin's world size, or ``1`` when the plugin does not define one."""
    # some training types define a world size
    return getattr(self.training_type_plugin, "world_size", 1)


@property
def should_rank_save_checkpoint(self) -> bool:
    """Pass-through to the training type plugin's ``should_rank_save_checkpoint``."""
    return self.training_type_plugin.should_rank_save_checkpoint


@property
def _distrib_type(self) -> DistributedType:
    """The distributed type stored on the accelerator connector."""
    return self._accelerator_connector._distrib_type


@property
def _device_type(self) -> DeviceType:
    """The device type stored on the accelerator connector."""
    return self._accelerator_connector._device_type


@property
def num_nodes(self) -> int:
    """Pass-through to the accelerator connector's ``num_nodes``."""
    return self._accelerator_connector.num_nodes


@property
def num_processes(self) -> int:
    """Pass-through to the accelerator connector's ``num_processes``."""
    return self._accelerator_connector.num_processes


@property
def root_gpu(self) -> Optional[int]:
    """Pass-through to the accelerator connector's ``root_gpu``."""
    return self._accelerator_connector.root_gpu


@property
def tpu_cores(self) -> int:
    """Pass-through to the accelerator connector's ``tpu_cores``."""
    return self._accelerator_connector.tpu_cores


@property
def ipus(self) -> int:
    """Pass-through to the accelerator connector's ``num_ipus``."""
    return self._accelerator_connector.num_ipus


@property
def num_gpus(self) -> int:
    """Pass-through to the accelerator connector's ``num_gpus``."""
    return self._accelerator_connector.num_gpus


@property
def devices(self) -> Optional[Union[List[int], str, int]]:
    """Pass-through to the accelerator connector's ``devices``."""
    return self._accelerator_connector.devices


@property
def data_parallel_device_ids(self) -> Optional[List[int]]:
    """Pass-through to the accelerator connector's ``parallel_device_ids``."""
    return self._accelerator_connector.parallel_device_ids


@property
def lightning_module(self) -> "pl.LightningModule":
    """The LightningModule held by the accelerator."""
    return self.accelerator.lightning_module


@property
def optimizers(self) -> List[Optimizer]:
    """The optimizers held by the accelerator."""
    return self.accelerator.optimizers


@optimizers.setter
def optimizers(self, new_optims: Optional[List[Optimizer]]) -> None:
    # Necessary to rewrap optimizers to lightning
    # They will be re-created when accessing
    # the `lightning_optimizers` trainer property
    self._lightning_optimizers = None

    self.accelerator.optimizers = new_optims


@property
def lr_schedulers(self) -> List[LRSchedulerTypeUnion]:
    """The learning-rate schedulers held by the accelerator."""
    return self.accelerator.lr_schedulers


@lr_schedulers.setter
def lr_schedulers(self, new_schedulers: List[LRSchedulerTypeUnion]) -> None:
    self.accelerator.lr_schedulers = new_schedulers


@property
def optimizer_frequencies(self) -> list:
    """The optimizer frequencies held by the accelerator."""
    return self.accelerator.optimizer_frequencies


@optimizer_frequencies.setter
def optimizer_frequencies(self, new_freqs: list) -> None:
    self.accelerator.optimizer_frequencies = new_freqs


@property
def amp_backend(self) -> Optional[str]:
    """Pass-through to the accelerator's ``amp_backend``."""
    return self.accelerator.amp_backend


@property
def precision(self) -> Union[str, int]:
    """Pass-through to the accelerator's ``precision``."""
    return self.accelerator.precision


@property
def scaler(self):
    """Pass-through to the accelerator's ``scaler``."""
    return self.accelerator.scaler


@property
def gpus(self) -> Optional[Union[List[int], str, int]]:
    """Pass-through to the accelerator connector's ``gpus``."""
    return self._accelerator_connector.gpus


@property
def model(self) -> torch.nn.Module:
    """The LightningModule, but possibly wrapped into DataParallel or DistributedDataParallel.

    To access the pure LightningModule, use
    :meth:`~pytorch_lightning.trainer.trainer.Trainer.lightning_module` instead.
    """
    return self.accelerator.model


@model.setter
def model(self, model: torch.nn.Module) -> None:
    """Setter for the model, pass-through to accelerator and plugin where the model reference is stored.
    Used by the Tuner to reset the state of Trainer and Accelerator.

    Args:
        model: The LightningModule, possibly wrapped into DataParallel or DistributedDataParallel, depending
            on the backend.
    """
    self.accelerator.model = model


# ----------------------------
# General properties
# ----------------------------


@property
def log_dir(self) -> Optional[str]:
    """The logging directory, chosen by logger type and broadcast through the training type plugin."""
    if self.logger is None:
        dirpath = self.default_root_dir
    elif isinstance(self.logger, TensorBoardLogger):
        dirpath = self.logger.log_dir
    elif isinstance(self.logger, LoggerCollection):
        dirpath = self.default_root_dir
    else:
        dirpath = self.logger.save_dir

    dirpath = self.training_type_plugin.broadcast(dirpath)
    return dirpath


@property
def use_amp(self) -> bool:
    """Whether ``precision`` equals 16."""
    return self.precision == 16


@property
def is_global_zero(self) -> bool:
    """Whether this process has global rank 0."""
    return self.global_rank == 0


@property
def slurm_job_id(self) -> Optional[int]:
    """The ``SLURM_JOB_ID`` environment value as an int; ``None`` if unparsable or the job name is ``bash``."""
    job_id = os.environ.get("SLURM_JOB_ID")
    if job_id:
        try:
            job_id = int(job_id)
        except ValueError:
            job_id = None

    # in interactive mode, don't make logs use the same job id
    in_slurm_interactive_mode = os.environ.get("SLURM_JOB_NAME") == "bash"
    if in_slurm_interactive_mode:
        job_id = None
    return job_id


@property
def lightning_optimizers(self) -> List[LightningOptimizer]:
    """The wrapped optimizers, created on first access via ``convert_to_lightning_optimizers``."""
    if self._lightning_optimizers is None:
        self.convert_to_lightning_optimizers()
    return self._lightning_optimizers


@property
def distributed_sampler_kwargs(self) -> Optional[dict]:
    """The plugin's ``distributed_sampler_kwargs`` for parallel plugins, otherwise ``None``."""
    if isinstance(self.training_type_plugin, ParallelPlugin):
        return self.training_type_plugin.distributed_sampler_kwargs


@property
def data_parallel(self) -> bool:
    """Whether the distributed type is one of DP, DDP, DDP_SPAWN or DDP2."""
    return self._distrib_type in (
        DistributedType.DP,
        DistributedType.DDP,
        DistributedType.DDP_SPAWN,
        DistributedType.DDP2,
    )


@property
def progress_bar_callback(self) -> Optional[ProgressBarBase]:
    """The progress bar callback stored on the trainer."""
    return self._progress_bar_callback


@property
def progress_bar_dict(self) -> dict:
    """Read-only for progress bar metrics."""
    rank_zero_deprecation(
        "`trainer.progress_bar_dict` is deprecated in v1.5 and will be removed in v1.7."
        " Use `ProgressBarBase.get_metrics` instead."
    )
    ref_model = self.lightning_module
    ref_model = cast(pl.LightningModule, ref_model)
    if self.progress_bar_callback:
        return self.progress_bar_callback.get_metrics(self, ref_model)
    return self.progress_bar_metrics


@property
def _should_reload_dl_epoch(self) -> bool:
    """Check if dataloader should be reloaded in the current epoch."""
    n_epochs = self.reload_dataloaders_every_n_epochs
    # coerce to a real bool: a falsy `n_epochs` (e.g. 0) used to leak out as the return value
    return bool(n_epochs) and not self.current_epoch % n_epochs


@property
def disable_validation(self) -> bool:
    """Check if validation is disabled during training."""
    rank_zero_deprecation(
        "`trainer.disable_validation` is deprecated in v1.4 and will be removed in v1.6."
        " Use `not trainer.enable_validation` instead."
    )
    return not self.enable_validation


@property
def enable_validation(self) -> bool:
    """Check if we should run validation during training."""
    model_ref = self.lightning_module
    val_loop_enabled = is_overridden("validation_step", model_ref) and self.limit_val_batches > 0
    return val_loop_enabled


@property
def default_root_dir(self) -> str:
    """The default location to save artifacts of loggers, checkpoints etc.

    It is used as a fallback if logger or checkpoint callback do not define specific save paths.
    """
    # only paths on the local ("file") filesystem are normalized
    if get_filesystem(self._default_root_dir).protocol == "file":
        return os.path.normpath(self._default_root_dir)
    return self._default_root_dir


@property
def weights_save_path(self) -> str:
    """
    The default root location to save weights (checkpoints), e.g., when the
    :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` does not define a file path.
    """
    # only paths on the local ("file") filesystem are normalized
    if get_filesystem(self._weights_save_path).protocol == "file":
        return os.path.normpath(self._weights_save_path)
    return self._weights_save_path


@property
def early_stopping_callback(self) -> Optional[EarlyStopping]:
    """The first :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` callback in the
    Trainer.callbacks list, or ``None`` if it doesn't exist."""
    callbacks = self.early_stopping_callbacks
    return callbacks[0] if len(callbacks) > 0 else None


@property
def early_stopping_callbacks(self) -> List[EarlyStopping]:
    """A list of all instances of :class:`~pytorch_lightning.callbacks.early_stopping.EarlyStopping` found in
    the Trainer.callbacks list."""
    return [c for c in self.callbacks if isinstance(c, EarlyStopping)]


@property
def prediction_writer_callbacks(self) -> List[BasePredictionWriter]:
    """A list of all instances of :class:`~pytorch_lightning.callbacks.prediction_writer.BasePredictionWriter`
    found in the Trainer.callbacks list."""
    return [cb for cb in self.callbacks if isinstance(cb, BasePredictionWriter)]


@property
def checkpoint_callback(self) -> Optional[ModelCheckpoint]:
    """The first :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callback in the
    Trainer.callbacks list, or ``None`` if it doesn't exist."""
    callbacks = self.checkpoint_callbacks
    return callbacks[0] if len(callbacks) > 0 else None


@property
def checkpoint_callbacks(self) -> List[ModelCheckpoint]:
    """A list of all instances of :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` found
    in the Trainer.callbacks list."""
    return [c for c in self.callbacks if isinstance(c, ModelCheckpoint)]


@property
def resume_from_checkpoint(self) -> Optional[Union[str, Path]]:
    """The checkpoint connector's fit resume path; emits a deprecation message when it is set."""
    resume_from_checkpoint = self.checkpoint_connector.resume_from_checkpoint_fit_path
    if resume_from_checkpoint is not None:
        rank_zero_deprecation(
            "`trainer.resume_from_checkpoint` is deprecated in v1.5 and will be removed in v1.7."
            " Specify the fit checkpoint path with `trainer.fit(ckpt_path=)` instead."
        )

    return resume_from_checkpoint


def save_checkpoint(self, filepath: _PATH, weights_only: bool = False) -> None:
    """Save a checkpoint to ``filepath`` through the checkpoint connector."""
    self.checkpoint_connector.save_checkpoint(filepath, weights_only)


# ----------------------------
# Parsing properties
# ----------------------------


@classmethod
def default_attributes(cls) -> dict:
    """Map every parameter name in the class's call signature to its default value."""
    init_signature = inspect.signature(cls)
    return {k: v.default for k, v in init_signature.parameters.items()}
@classmethod
def get_deprecated_arg_names(cls) -> List:
    """Collect the deprecated Trainer argument names.

    Every class attribute whose name starts with ``DEPRECATED`` and whose value is a tuple or list
    contributes its items, in class-definition order.
    """
    return [
        arg_name
        for attr_name, attr_value in cls.__dict__.items()
        if attr_name.startswith("DEPRECATED") and isinstance(attr_value, (tuple, list))
        for arg_name in attr_value
    ]
@classmethod
def from_argparse_args(cls: Any, args: Union[Namespace, ArgumentParser], **kwargs) -> Any:
    """Delegate to the module-level ``from_argparse_args`` helper, passing ``cls``, ``args`` and ``kwargs``."""
    return from_argparse_args(cls, args, **kwargs)


@classmethod
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Delegate to the module-level ``parse_argparser`` helper, passing ``cls`` and ``arg_parser``."""
    return parse_argparser(cls, arg_parser)


@classmethod
def match_env_arguments(cls) -> Namespace:
    """Delegate to the module-level ``parse_env_variables`` helper, passing ``cls``."""
    return parse_env_variables(cls)


@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
    """Delegate to the module-level ``add_argparse_args`` helper, passing ``cls``, ``parent_parser`` and ``kwargs``."""
    return add_argparse_args(cls, parent_parser, **kwargs)


"""
State properties
"""


@property
def interrupted(self) -> bool:
    """Whether ``self.state.status`` equals ``TrainerStatus.INTERRUPTED``."""
    return self.state.status == TrainerStatus.INTERRUPTED


@property
def training(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.TRAINING``."""
    return self.state.stage == RunningStage.TRAINING


@training.setter
def training(self, val: bool) -> None:
    """Set the stage to ``TRAINING`` when ``val`` is truthy; otherwise clear it only if it currently is ``TRAINING``."""
    if val:
        self.state.stage = RunningStage.TRAINING
    elif self.training:
        # A falsy value must not wipe out a different, unrelated stage.
        self.state.stage = None


@property
def testing(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.TESTING``."""
    return self.state.stage == RunningStage.TESTING


@testing.setter
def testing(self, val: bool) -> None:
    """Set the stage to ``TESTING`` when ``val`` is truthy; otherwise clear it only if it currently is ``TESTING``."""
    if val:
        self.state.stage = RunningStage.TESTING
    elif self.testing:
        self.state.stage = None


@property
def predicting(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.PREDICTING``."""
    return self.state.stage == RunningStage.PREDICTING


@predicting.setter
def predicting(self, val: bool) -> None:
    """Set the stage to ``PREDICTING`` when ``val`` is truthy; otherwise clear it only if it currently is ``PREDICTING``."""
    if val:
        self.state.stage = RunningStage.PREDICTING
    elif self.predicting:
        self.state.stage = None


@property
def tuning(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.TUNING``."""
    return self.state.stage == RunningStage.TUNING


@tuning.setter
def tuning(self, val: bool) -> None:
    """Set the stage to ``TUNING`` when ``val`` is truthy; otherwise clear it only if it currently is ``TUNING``."""
    if val:
        self.state.stage = RunningStage.TUNING
    elif self.tuning:
        self.state.stage = None


@property
def validating(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.VALIDATING``."""
    return self.state.stage == RunningStage.VALIDATING


@validating.setter
def validating(self, val: bool) -> None:
    """Set the stage to ``VALIDATING`` when ``val`` is truthy; otherwise clear it only if it currently is ``VALIDATING``."""
    if val:
        self.state.stage = RunningStage.VALIDATING
    elif self.validating:
        self.state.stage = None


@property
def evaluating(self) -> bool:
    """Whether a stage is set and that stage's ``evaluating`` attribute is truthy.

    The result is coerced with ``bool`` so that an unset stage yields ``False``
    rather than leaking ``None`` through a ``-> bool`` annotation.
    """
    return bool(self.state.stage and self.state.stage.evaluating)


@property
def sanity_checking(self) -> bool:
    """Whether ``self.state.stage`` equals ``RunningStage.SANITY_CHECKING``."""
    return self.state.stage == RunningStage.SANITY_CHECKING


@sanity_checking.setter
def sanity_checking(self, val: bool) -> None:
    """Set the stage to ``SANITY_CHECKING`` when ``val`` is truthy; otherwise clear it only if it currently is that stage."""
    if val:
        self.state.stage = RunningStage.SANITY_CHECKING
    elif self.sanity_checking:
        self.state.stage = None


"""
Loop properties
"""


@property
def global_step(self) -> int:
    """The ``global_step`` of the fit loop."""
    return self.fit_loop.global_step


@property
def current_epoch(self) -> int:
    """The ``current_epoch`` of the fit loop."""
    return self.fit_loop.current_epoch


@property
def max_epochs(self) -> int:
    """The ``max_epochs`` of the fit loop."""
    return self.fit_loop.max_epochs


@property
def min_epochs(self) -> Optional[int]:
    """The ``min_epochs`` of the fit loop."""
    return self.fit_loop.min_epochs


@property
def max_steps(self) -> int:
    """The ``max_steps`` of the fit loop."""
    return self.fit_loop.max_steps


@property
def min_steps(self) -> Optional[int]:
    """The ``min_steps`` of the fit loop."""
    return self.fit_loop.min_steps


@property
def is_last_batch(self) -> bool:
    """The ``is_last_batch`` flag of the fit loop's epoch-loop batch progress."""
    return self.fit_loop.epoch_loop.batch_progress.is_last_batch


@property
def fit_loop(self) -> FitLoop:
    """The fit loop currently attached to this Trainer."""
    return self._fit_loop


@fit_loop.setter
def fit_loop(self, loop: FitLoop):
    """Attach a custom fit loop to this Trainer.

    It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.fit`.
    """
    # Give the loop a back-reference to this trainer before storing it.
    loop.trainer = self
    self._fit_loop = loop


@property
def validate_loop(self) -> EvaluationLoop:
    """The validation loop currently attached to this Trainer."""
    return self._validate_loop


@validate_loop.setter
def validate_loop(self, loop: EvaluationLoop):
    """Attach a custom validation loop to this Trainer.

    It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.validate`. Note that this loop is
    different from the one running during training inside the
    :meth:`pytorch_lightning.trainer.trainer.Trainer.fit` call.
    """
    loop.trainer = self
    self._validate_loop = loop


@property
def test_loop(self) -> EvaluationLoop:
    """The test loop currently attached to this Trainer."""
    return self._test_loop


@test_loop.setter
def test_loop(self, loop: EvaluationLoop):
    """Attach a custom test loop to this Trainer.

    It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.test`.
    """
    loop.trainer = self
    self._test_loop = loop


@property
def predict_loop(self) -> PredictionLoop:
    """The prediction loop currently attached to this Trainer."""
    return self._predict_loop


@predict_loop.setter
def predict_loop(self, loop: PredictionLoop):
    """Attach a custom prediction loop to this Trainer.

    It will run with :meth:`~pytorch_lightning.trainer.trainer.Trainer.predict`.
    """
    loop.trainer = self
    self._predict_loop = loop


@property
def _evaluation_loop(self) -> EvaluationLoop:
    """The evaluation loop matching the trainer function stored in ``self.state.fn``.

    Raises:
        RuntimeError: If ``self.state.fn`` is none of FITTING, TUNING, VALIDATING or TESTING.
    """
    if self.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
        # While fitting/tuning, evaluation is the val loop nested in the fit loop's epoch loop.
        return self.fit_loop.epoch_loop.val_loop
    if self.state.fn == TrainerFn.VALIDATING:
        return self.validate_loop
    if self.state.fn == TrainerFn.TESTING:
        return self.test_loop
    raise RuntimeError("The `Trainer._evaluation_loop` property isn't defined. Accessed outside of scope")


@property
def _active_loop(self) -> Optional[Union[FitLoop, EvaluationLoop, PredictionLoop]]:
    """The loop matching the current stage, or ``None`` when no stage flag applies."""
    if self.training:
        return self.fit_loop
    if self.sanity_checking or self.evaluating:
        return self._evaluation_loop
    if self.predicting:
        return self.predict_loop
    return None


"""
Logging properties
"""


@property
def callback_metrics(self) -> dict:
    """The ``callback_metrics`` held by the logger connector."""
    return self.logger_connector.callback_metrics


@property
def logged_metrics(self) -> dict:
    """The ``logged_metrics`` held by the logger connector."""
    return self.logger_connector.logged_metrics


@property
def progress_bar_metrics(self) -> dict:
    """The ``progress_bar_metrics`` held by the logger connector."""
    return self.logger_connector.progress_bar_metrics


@property
def _results(self) -> Optional[ResultCollection]:
    """The ``_results`` of the active loop, or ``None`` when there is no active loop."""
    active_loop = self._active_loop
    if active_loop is not None:
        return active_loop._results
    return None


def _exit_gracefully_on_signal(self) -> None:
    """Raise ``ExitGracefullyException`` if graceful termination was requested.

    Nothing happens unless ``_fault_tolerant_training()`` is truthy and
    ``self._terminate_gracefully`` is set. The exception message names the class
    and function of the immediate caller, taken from the call stack.

    Raises:
        ExitGracefullyException: When both conditions above hold.
    """
    if _fault_tolerant_training() and self._terminate_gracefully:
        # Frame 1 is whoever called this method.
        caller = inspect.stack()[1]
        # NOTE(review): relies on the calling frame having a local named ``self``;
        # a call from a plain function would raise KeyError here.
        class_name = caller[0].f_locals["self"].__class__.__name__
        raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")


@property
def weights_summary(self) -> Optional[str]:
    """Deprecated accessor for ``self._weights_summary``; emits a deprecation notice on every read."""
    rank_zero_deprecation("`Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
    return self._weights_summary


@weights_summary.setter
def weights_summary(self, val: Optional[str]) -> None:
    """Deprecated setter for ``self._weights_summary``; emits a deprecation notice on every write."""
    rank_zero_deprecation("Setting `Trainer.weights_summary` is deprecated in v1.5 and will be removed in v1.7.")
    self._weights_summary = val


"""
Other
"""


# TODO: refactor this so that it can be done in LightningOptimizer
def __getstate__(self):
    """Return the instance ``__dict__`` for pickling after clearing ``_lightning_optimizers``.

    NOTE(review): this resets ``_lightning_optimizers`` on the live object, not on a
    copy, so serialising the trainer also changes it. Confirm callers expect that.
    """
    # remove lightning_optimizers
    self._lightning_optimizers = None
    return self.__dict__


def __setstate__(self, state):
    """Restore the instance by adopting ``state`` as its ``__dict__``."""
    self.__dict__ = state


@property
def train_loop(self) -> FitLoop:
    """Deprecated alias of ``fit_loop``; emits a deprecation notice on every read."""
    rank_zero_deprecation("`Trainer.train_loop` has been renamed to `Trainer.fit_loop` and will be removed in v1.6.")
    return self.fit_loop


@property
def terminate_on_nan(self) -> bool:
    """Deprecated accessor for ``self._terminate_on_nan``; emits a deprecation notice on every read."""
    rank_zero_deprecation("`Trainer.terminate_on_nan` is deprecated in v1.5 and will be removed in 1.7.")
    return self._terminate_on_nan


@terminate_on_nan.setter
def terminate_on_nan(self, val: bool) -> None:
    """Deprecated setter for ``self._terminate_on_nan``; emits a deprecation notice on every write."""
    rank_zero_deprecation(
        f"Setting `Trainer.terminate_on_nan = {val}` is deprecated in v1.5 and will be removed in 1.7."
        f" Please set `Trainer(detect_anomaly={val})` instead."
    )
    self._terminate_on_nan = val  # : 212
def_determine_batch_limits(batches:Union[int,float],name:str)->Union[int,float]:if0<=batches<=1:returnbatchesifbatches>1andbatches%1.0==0:returnint(batches)raiseMisconfigurationException(f"You have passed invalid value {batches} for {name}, it has to be in [0.0, 1.0] or an int.")
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.