# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import logging
import os
from functools import wraps
from typing import Callable, Optional, Sequence

import numpy as np
import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.loggers.base import DummyLogger
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.parsing import lightning_hasattr, lightning_setattr

# check if ipywidgets is installed before importing tqdm.auto
# to ensure it won't fail and a progress bar is displayed
if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

log = logging.getLogger(__name__)


def _determine_lr_attr_name(trainer: "pl.Trainer", model: "pl.LightningModule") -> str:
    if isinstance(trainer.auto_lr_find, str):
        if not lightning_hasattr(model, trainer.auto_lr_find):
            raise MisconfigurationException(
                f"`auto_lr_find` was set to {trainer.auto_lr_find}, however"
                " could not find this as a field in `model` or `model.hparams`."
            )
        return trainer.auto_lr_find

    attr_options = ("lr", "learning_rate")
    for attr in attr_options:
        if lightning_hasattr(model, attr):
            return attr

    raise MisconfigurationException(
        "When `auto_lr_find=True`, either `model` or `model.hparams` should"
        f" have one of these fields: {attr_options} overridden."
    )


class _LRFinder:
    """LR finder object. This object stores the results of lr_find().

    Args:
        mode: either `linear` or `exponential`, how to increase lr after each step

        lr_min: lr to start search from

        lr_max: lr to stop search

        num_training: number of steps to take between lr_min and lr_max

    Example::

        # Run lr finder
        lr_finder = trainer.lr_find(model)

        # Results stored in
        lr_finder.results

        # Plot using
        lr_finder.plot()

        # Get suggestion
        lr = lr_finder.suggestion()
    """

    def __init__(self, mode: str, lr_min: float, lr_max: float, num_training: int):
        assert mode in ("linear", "exponential"), "mode should be either `linear` or `exponential`"

        self.mode = mode
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.num_training = num_training

        self.results = {}
        self._total_batch_idx = 0  # for debug purpose

    def _exchange_scheduler(self, configure_optimizers: Callable):
        """Decorate `configure_optimizers` so that it returns the user's originally specified optimizer
        together with a new scheduler that takes care of the learning rate search.
        """
"""@wraps(configure_optimizers)deffunc():# Decide the structure of the output from configure_optimizers# Same logic as method `init_optimizers` in trainer/optimizers.pyoptim_conf=configure_optimizers()ifisinstance(optim_conf,Optimizer):optimizers=[optim_conf]elifisinstance(optim_conf,(list,tuple))andlen(optim_conf)==2andisinstance(optim_conf[0],list):optimizers,_=optim_confelifisinstance(optim_conf,dict):optimizers=[optim_conf["optimizer"]]elifisinstance(optim_conf,(list,tuple))andisinstance(optim_conf[0],dict):optimizers=[opt_dict["optimizer"]foropt_dictinoptim_conf]elifisinstance(optim_conf,(list,tuple)):optimizers=[optim_conf]iflen(optimizers)!=1:raiseMisconfigurationException(f"`model.configure_optimizers()` returned {len(optimizers)}, but"" learning rate finder only works with single optimizer")optimizer=optimizers[0]new_lrs=[self.lr_min]*len(optimizer.param_groups)forparam_group,new_lrinzip(optimizer.param_groups,new_lrs):param_group["lr"]=new_lrparam_group["initial_lr"]=new_lrargs=(optimizer,self.lr_max,self.num_training)scheduler=_LinearLR(*args)ifself.mode=="linear"else_ExponentialLR(*args)return[optimizer],[{"scheduler":scheduler,"interval":"step"}]returnfuncdefplot(self,suggest:bool=False,show:bool=False):"""Plot results from lr_find run Args: suggest: if True, will mark suggested lr to use with a red point show: if True, will show figure """importmatplotlib.pyplotaspltlrs=self.results["lr"]losses=self.results["loss"]fig,ax=plt.subplots()# Plot loss as a function of the learning rateax.plot(lrs,losses)ifself.mode=="exponential":ax.set_xscale("log")ax.set_xlabel("Learning rate")ax.set_ylabel("Loss")ifsuggest:_=self.suggestion()ifself._optimal_idx:ax.plot(lrs[self._optimal_idx],losses[self._optimal_idx],markersize=10,marker="o",color="red")ifshow:plt.show()returnfigdefsuggestion(self,skip_begin:int=10,skip_end:int=1):"""This will propose a suggestion for choice of initial learning rate as the point with the steepest negative gradient. Returns: lr: suggested initial learning rate to use skip_begin: how many samples to skip in the beginning. Prevent too naive estimates skip_end: how many samples to skip in the end. Prevent too optimistic estimates """try:loss=np.array(self.results["loss"][skip_begin:-skip_end])loss=loss[np.isfinite(loss)]min_grad=np.gradient(loss).argmin()self._optimal_idx=min_grad+skip_beginreturnself.results["lr"][self._optimal_idx]# todo: specify the possible exceptionexceptException:log.exception("Failed to compute suggesting for `lr`. There might not be enough points.")self._optimal_idx=None
def lr_find(
    trainer: "pl.Trainer",
    model: "pl.LightningModule",
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: float = 4.0,
    update_attr: bool = False,
) -> Optional[_LRFinder]:
    """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.lr_find`"""
    if trainer.fast_dev_run:
        rank_zero_warn("Skipping learning rate finder since fast_dev_run is enabled.", UserWarning)
        return

    # Determine lr attr
    if update_attr:
        lr_attr_name = _determine_lr_attr_name(trainer, model)

    save_path = os.path.join(trainer.default_root_dir, "lr_find_temp_model.ckpt")

    __lr_finder_dump_params(trainer, model)

    # Prevent going into infinite loop
    trainer.auto_lr_find = False

    # Initialize lr finder object (stores results)
    lr_finder = _LRFinder(mode, min_lr, max_lr, num_training)

    # Use special lr logger callback
    trainer.callbacks = [_LRCallback(num_training, early_stop_threshold, progress_bar_refresh_rate=1)]

    # No logging
    trainer.logger = DummyLogger()

    # Max step set to number of iterations
    trainer.fit_loop.max_steps = num_training

    # Disable standard progress bar for fit
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.disable()

    # Required for saving the model
    trainer.optimizers, trainer.lr_schedulers = [], []
    trainer.model = model

    # Dump model checkpoint
    trainer.save_checkpoint(str(save_path))

    # Configure optimizer and scheduler
    model.configure_optimizers = lr_finder._exchange_scheduler(model.configure_optimizers)

    # Fit, lr & loss logged in callback
    trainer.tuner._run(model)

    # Prompt if we stopped early
    if trainer.global_step != num_training:
        log.info(f"LR finder stopped early after {trainer.global_step} steps due to diverging loss.")

    # Transfer results from callback to lr finder object
    lr_finder.results.update({"lr": trainer.callbacks[0].lrs, "loss": trainer.callbacks[0].losses})
    lr_finder._total_batch_idx = trainer.fit_loop.total_batch_idx  # for debug purpose

    # Reset model state
    if trainer.is_global_zero:
        trainer.checkpoint_connector.restore(str(save_path))
        fs = get_filesystem(str(save_path))
        if fs.exists(save_path):
            fs.rm(save_path)

    # Finish by resetting variables so trainer is ready to fit model
    __lr_finder_restore_params(trainer, model)
    if trainer.progress_bar_callback:
        trainer.progress_bar_callback.enable()

    # Update lr attr if required
    if update_attr:
        lr = lr_finder.suggestion()

        # TODO: log lr.results to self.logger
        lightning_setattr(model, lr_attr_name, lr)
        log.info(f"Learning rate set to {lr}")

    return lr_finder
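
# A minimal usage sketch of `lr_find` (`MyModel` is a hypothetical LightningModule,
# not part of this module). The supported entry points are `Trainer.tune` with
# `auto_lr_find=True`, which applies the suggestion automatically, or
# `trainer.tuner.lr_find`, which returns the `_LRFinder` results object defined above:
#
#     import pytorch_lightning as pl
#
#     model = MyModel()
#
#     # automatic: tune() runs lr_find() and writes the suggestion to the lr attribute
#     trainer = pl.Trainer(auto_lr_find=True)
#     trainer.tune(model)
#
#     # manual: keep control of the results
#     trainer = pl.Trainer()
#     lr_finder = trainer.tuner.lr_find(model, min_lr=1e-8, max_lr=1, num_training=100)
#     model.learning_rate = lr_finder.suggestion()
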
def __lr_finder_dump_params(trainer, model):
    # Prevent going into infinite loop
    trainer.__dumped_params = {
        "auto_lr_find": trainer.auto_lr_find,
        "callbacks": trainer.callbacks,
        "logger": trainer.logger,
        "max_steps": trainer.max_steps,
        "checkpoint_callback": trainer.checkpoint_callback,
        "current_epoch": trainer.current_epoch,
        "configure_optimizers": model.configure_optimizers,
    }


def __lr_finder_restore_params(trainer, model):
    trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
    trainer.logger = trainer.__dumped_params["logger"]
    trainer.callbacks = trainer.__dumped_params["callbacks"]
    trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
    trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
    model.configure_optimizers = trainer.__dumped_params["configure_optimizers"]
    del trainer.__dumped_params


class _LRCallback(Callback):
    """Special callback used by the learning rate finder. This callback logs the learning rate before each batch
    and the corresponding loss after each batch.

    Args:
        num_training: number of iterations done by the learning rate finder
        early_stop_threshold: threshold for stopping the search. If the
            loss at any point is larger than ``early_stop_threshold*best_loss``
            then the search is stopped. To disable, set to ``None``.
        progress_bar_refresh_rate: rate to refresh the progress bar for the learning rate finder
        beta: smoothing value, the loss being logged is a running average of
            loss values logged until now. ``beta`` controls the forget rate i.e.
            if ``beta=0`` all past information is ignored.
    """

    def __init__(
        self,
        num_training: int,
        early_stop_threshold: float = 4.0,
        progress_bar_refresh_rate: int = 0,
        beta: float = 0.98,
    ):
        self.num_training = num_training
        self.early_stop_threshold = early_stop_threshold
        self.beta = beta
        self.losses = []
        self.lrs = []
        self.avg_loss = 0.0
        self.best_loss = 0.0
        self.progress_bar_refresh_rate = progress_bar_refresh_rate
        self.progress_bar = None

    def on_batch_start(self, trainer, pl_module):
        """Called before each training batch, logs the lr that will be used"""
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar_refresh_rate and self.progress_bar is None:
            self.progress_bar = tqdm(desc="Finding best initial lr", total=self.num_training)

        self.lrs.append(trainer.lr_schedulers[0]["scheduler"].lr[0])

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Called when the training batch ends, logs the calculated loss"""
        if (trainer.fit_loop.batch_idx + 1) % trainer.accumulate_grad_batches != 0:
            return

        if self.progress_bar:
            self.progress_bar.update()

        current_loss = trainer.fit_loop.running_loss.last().item()
        current_step = trainer.global_step

        # Avg loss (loss with momentum) + smoothing
        self.avg_loss = self.beta * self.avg_loss + (1 - self.beta) * current_loss
        smoothed_loss = self.avg_loss / (1 - self.beta ** (current_step + 1))

        # Check if we are diverging
        if self.early_stop_threshold is not None:
            if current_step > 1 and smoothed_loss > self.early_stop_threshold * self.best_loss:
                trainer.fit_loop.max_steps = current_step  # stop signal
                if self.progress_bar:
                    self.progress_bar.close()

        # Save best loss for divergence checking
        if smoothed_loss < self.best_loss or current_step == 1:
            self.best_loss = smoothed_loss

        self.losses.append(smoothed_loss)


class _LinearLR(_LRScheduler):
    """Linearly increases the learning rate between two boundaries over a number of iterations.

    Args:
        optimizer: wrapped optimizer.

        end_lr: the final learning rate.

        num_iter: the number of iterations over which the test occurs.

        last_epoch: the index of last epoch. Default: -1.
    """
"""last_epoch:intbase_lrs:Sequencedef__init__(self,optimizer:torch.optim.Optimizer,end_lr:float,num_iter:int,last_epoch:int=-1):self.end_lr=end_lrself.num_iter=num_itersuper().__init__(optimizer,last_epoch)defget_lr(self):curr_iter=self.last_epoch+1r=curr_iter/self.num_iterifself.last_epoch>0:val=[base_lr+r*(self.end_lr-base_lr)forbase_lrinself.base_lrs]else:val=[base_lrforbase_lrinself.base_lrs]self._lr=valreturnval@propertydeflr(self):returnself._lrclass_ExponentialLR(_LRScheduler):"""Exponentially increases the learning rate between two boundaries over a number of iterations. Arguments: optimizer: wrapped optimizer. end_lr: the final learning rate. num_iter: the number of iterations over which the test occurs. last_epoch: the index of last epoch. Default: -1. """last_epoch:intbase_lrs:Sequencedef__init__(self,optimizer:torch.optim.Optimizer,end_lr:float,num_iter:int,last_epoch:int=-1):self.end_lr=end_lrself.num_iter=num_itersuper().__init__(optimizer,last_epoch)defget_lr(self):curr_iter=self.last_epoch+1r=curr_iter/self.num_iterifself.last_epoch>0:val=[base_lr*(self.end_lr/base_lr)**rforbase_lrinself.base_lrs]else:val=[base_lrforbase_lrinself.base_lrs]self._lr=valreturnval@propertydeflr(self):returnself._lr