# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.types import _Stateful, LRSchedulerConfig, LRSchedulerTypeTuple, ReduceLROnPlateau


def do_nothing_closure() -> None:
    return
class LightningOptimizer:
    """This class is used to wrap the user optimizers and handle the backward and ``optimizer_step`` logic correctly
    across accelerators, AMP and ``accumulate_grad_batches``."""

    def __init__(self, optimizer: Optimizer):
        # copy most of the `Optimizer` methods into this instance. `__del__` is skipped in case the optimizer has
        # implemented custom logic which we would not want to call on destruction of the `LightningOptimizer`
        self.__dict__ = {k: v for k, v in optimizer.__dict__.items() if k not in ("step", "__del__")}

        # For Horovod
        if hasattr(optimizer, "skip_synchronize"):
            self.__class__ = type(
                "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__.__bases__[0]), {}
            )
            self.skip_synchronize = optimizer.skip_synchronize
            self.synchronize = optimizer.synchronize
        else:
            self.__class__ = type(
                "Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {}
            )

        self._optimizer = optimizer
        self._strategy: Optional[pl.strategies.Strategy] = None
        self._optimizer_idx = 0
        # to inject logic around the optimizer step, particularly useful with manual optimization
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer

    @classmethod
    def _to_lightning_optimizer(
        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy", opt_idx: int
    ) -> "LightningOptimizer":
        if isinstance(optimizer, LightningOptimizer):
            # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
            # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
            lightning_optimizer = optimizer
        else:
            lightning_optimizer = cls(optimizer)
        lightning_optimizer._strategy = proxy(strategy)
        lightning_optimizer._optimizer_idx = opt_idx
        return lightning_optimizer
    @contextmanager
    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
        """This function is just a helper for advanced users.

        Consider the current optimizer as A and all other optimizers as B. Toggling means that all parameters of B
        that are not also in A will have ``requires_grad`` set to ``False``.

        When performing gradient accumulation, there is no need to perform grad synchronization during the
        accumulation phase. Setting ``sync_grad`` to ``False`` will block this synchronization and improve
        performance.
        """
        # local import here to avoid circular import
        from pytorch_lightning.loops.utilities import _block_parallel_sync_behavior

        assert self._strategy is not None
        lightning_module = self._strategy.lightning_module
        assert lightning_module is not None
        with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)):
            lightning_module.toggle_optimizer(self, self._optimizer_idx)
            yield
            lightning_module.untoggle_optimizer(self._optimizer_idx)
    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        """Performs a single optimization step (parameter update).

        Args:
            closure: An optional optimizer closure.
            kwargs: Any additional arguments to the ``optimizer.step()`` call.

        Returns:
            The output from the step call, which is generally the output of the closure execution.

        Example::

            # Scenario for a GAN using manual optimization
            def training_step(...):
                opt_gen, opt_dis = self.optimizers()

                ...

                # compute generator loss
                loss_gen = self.compute_generator_loss(...)
                # zero_grad needs to be called before backward
                opt_gen.zero_grad()
                self.manual_backward(loss_gen)
                opt_gen.step()

                # compute discriminator loss
                loss_dis = self.compute_discriminator_loss(...)

                # zero_grad needs to be called before backward
                opt_dis.zero_grad()
                self.manual_backward(loss_dis)
                opt_dis.step()


            # A more advanced example
            def training_step(self, batch, batch_idx, ...):
                opt_gen, opt_dis = self.optimizers()

                ...
                accumulated_grad_batches = batch_idx % 2 == 0

                # compute generator loss
                def closure_gen():
                    loss_gen = self.compute_generator_loss(...)
                    self.manual_backward(loss_gen)
                    if accumulated_grad_batches:
                        opt_gen.zero_grad()

                with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_gen.step(closure=closure_gen)

                def closure_dis():
                    loss_dis = self.compute_discriminator_loss(...)
                    self.manual_backward(loss_dis)
                    if accumulated_grad_batches:
                        opt_dis.zero_grad()

                with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_dis.step(closure=closure_dis)
        """
        self._on_before_step()

        if closure is None:
            closure = do_nothing_closure
        elif not callable(closure):
            raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")

        assert self._strategy is not None
        step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

        self._on_after_step()

        return step_output
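
# Illustrative sketch: a minimal example of how `LightningOptimizer` wraps a plain
# `torch.optim.Optimizer`. The wrapper copies the optimizer's `__dict__` and dynamically
# subclasses the optimizer's class, so attribute access and `isinstance` checks keep working.
# The helper name `_example_wrap_optimizer` is hypothetical and used only for illustration;
# it is not part of the Lightning API.
def _example_wrap_optimizer() -> None:
    params = [torch.nn.Parameter(torch.zeros(1))]
    sgd = optim.SGD(params, lr=0.1)
    wrapped = LightningOptimizer(sgd)

    assert isinstance(wrapped, LightningOptimizer)
    assert isinstance(wrapped, optim.SGD)  # thanks to the dynamically created subclass
    assert wrapped.optimizer is sgd
    assert wrapped.param_groups == sgd.param_groups  # attributes copied via `optimizer.__dict__`
    # `wrapped.step()` routes through the strategy, so it is only meaningful once
    # `_to_lightning_optimizer` has attached a `Strategy` (done by the Trainer).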
def _init_optimizers_and_lr_schedulers(
    model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig], List[int]]:
    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
    assert model.trainer is not None
    optim_conf = model.trainer._call_lightning_module_hook("configure_optimizers", pl_module=model)

    if optim_conf is None:
        rank_zero_warn(
            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, optimizer_frequencies, monitor = _configure_optimizers(optim_conf)
    lr_scheduler_configs = (
        _configure_schedulers_automatic_opt(lr_schedulers, monitor)
        if model.automatic_optimization
        else _configure_schedulers_manual_opt(lr_schedulers)
    )
    _set_scheduler_opt_idx(optimizers, lr_scheduler_configs)
    _validate_scheduler_api(lr_scheduler_configs, model)
    return optimizers, lr_scheduler_configs, optimizer_frequencies


def _configure_optimizers(
    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple]
) -> Tuple[List, List, List, Optional[str]]:
    optimizers, lr_schedulers, optimizer_frequencies = [], [], []
    monitor = None

    # single output, single optimizer
    if isinstance(optim_conf, Optimizer):
        optimizers = [optim_conf]
    # two lists, optimizer + lr schedulers
    elif (
        isinstance(optim_conf, (list, tuple))
        and len(optim_conf) == 2
        and isinstance(optim_conf[0], list)
        and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
    ):
        opt, sch = optim_conf
        optimizers = opt
        lr_schedulers = sch if isinstance(sch, list) else [sch]
    # single dictionary
    elif isinstance(optim_conf, dict):
        _validate_optim_conf(optim_conf)
        optimizers = [optim_conf["optimizer"]]
        monitor = optim_conf.get("monitor", None)
        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
        for opt_dict in optim_conf:
            _validate_optim_conf(opt_dict)
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        scheduler_dict = (
            lambda scheduler, opt_idx: dict(scheduler, opt_idx=opt_idx)
            if isinstance(scheduler, dict)
            else {"scheduler": scheduler, "opt_idx": opt_idx}
        )

        lr_schedulers = [
            scheduler_dict(opt_dict["lr_scheduler"], opt_idx)
            for opt_idx, opt_dict in enumerate(optim_conf)
            if "lr_scheduler" in opt_dict
        ]
        optimizer_frequencies = [
            opt_dict["frequency"] for opt_dict in optim_conf if opt_dict.get("frequency", None) is not None
        ]
        # assert that if frequencies are present, they are given for all optimizers
        if optimizer_frequencies and len(optimizer_frequencies) != len(optimizers):
            raise ValueError("A frequency must be given to each optimizer.")
    # single list or tuple, multiple optimizers
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
        optimizers = list(optim_conf)
    # unknown configuration
    else:
        raise MisconfigurationException(
            "Unknown configuration for model optimizers."
            " Output from `model.configure_optimizers()` should be one of:\n"
            " * `Optimizer`\n"
            " * [`Optimizer`]\n"
            " * ([`Optimizer`], [`_LRScheduler`])\n"
            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `_LRScheduler`}\n'
            ' * A list of the previously described dict format, with an optional "frequency" key (int)'
        )
    return optimizers, lr_schedulers, optimizer_frequencies, monitor
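
# Illustrative sketch: a few of the `configure_optimizers` return formats accepted by
# `_configure_optimizers` above and how they are parsed. The helper name
# `_example_configure_optimizers_formats` is hypothetical and exists purely for illustration.
def _example_configure_optimizers_formats() -> None:
    params = [torch.nn.Parameter(torch.zeros(1))]
    opt_a, opt_b = optim.SGD(params, lr=0.1), optim.Adam(params, lr=1e-3)
    sch = optim.lr_scheduler.StepLR(opt_a, step_size=1)

    # single optimizer
    optimizers, schedulers, frequencies, monitor = _configure_optimizers(opt_a)
    assert optimizers == [opt_a] and schedulers == [] and frequencies == [] and monitor is None

    # two lists: optimizers + lr schedulers
    optimizers, schedulers, _, _ = _configure_optimizers(([opt_a, opt_b], [sch]))
    assert optimizers == [opt_a, opt_b] and schedulers == [sch]

    # list of dictionaries, optionally with a "frequency" per optimizer
    optimizers, schedulers, frequencies, _ = _configure_optimizers(
        [
            {"optimizer": opt_a, "lr_scheduler": sch, "frequency": 1},
            {"optimizer": opt_b, "frequency": 2},
        ]
    )
    assert frequencies == [1, 2]
    # scheduler entries are normalized to dicts carrying the index of their optimizer
    assert schedulers == [{"scheduler": sch, "opt_idx": 0}]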
def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # check provided keys
            supported_keys = {field.name for field in fields(LRSchedulerConfig)}
            extra_keys = scheduler.keys() - supported_keys
            if extra_keys:
                rank_zero_warn(
                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
                    " HINT: remove them from the output of `configure_optimizers`.",
                    category=RuntimeWarning,
                )
                scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
            if "scheduler" not in scheduler:
                raise MisconfigurationException(
                    'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                )
            if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
                raise MisconfigurationException(
                    'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                    f' but is "{scheduler["interval"]}"'
                )
            scheduler["reduce_on_plateau"] = isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
            if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
                raise MisconfigurationException(
                    "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                    ' For example: {"optimizer": optimizer, "lr_scheduler":'
                    ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                )
            is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
            if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
                rank_zero_warn(
                    "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
                    " Are you sure you didn't mean 'interval': 'step'?",
                    category=RuntimeWarning,
                )
            config = LRSchedulerConfig(**scheduler)
        elif isinstance(scheduler, ReduceLROnPlateau):
            if monitor is None:
                raise MisconfigurationException(
                    "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                    " scheduler is used. For example:"
                    ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                )
            config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            invalid_keys = {"interval", "frequency", "reduce_on_plateau", "monitor", "strict"}
            keys_to_warn = [k for k in scheduler.keys() if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
    for config in lr_scheduler_configs:
        scheduler = config.scheduler
        if not isinstance(scheduler, _Stateful):
            raise TypeError(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                " It should have `state_dict` and `load_state_dict` methods defined."
            )

        if not isinstance(scheduler, LRSchedulerTypeTuple) and not is_overridden("lr_scheduler_step", model):
            raise MisconfigurationException(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
                " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
                " you are using a custom LR scheduler."
            )
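
# Illustrative sketch: a scheduler configuration that passes the automatic-optimization checks
# above. The dict carries the mandatory "scheduler" key, a valid "interval", and, because the
# scheduler is a `ReduceLROnPlateau`, the required "monitor" key. The helper name
# `_example_optimizer_config` and the "val_loss" metric name are hypothetical, for illustration only.
def _example_optimizer_config() -> Dict[str, Any]:
    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.Adam(params, lr=1e-3)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
    return {
        "optimizer": optimizer,
        "lr_scheduler": {
            "scheduler": scheduler,
            "interval": "epoch",  # must be "step" or "epoch"
            "frequency": 1,
            "monitor": "val_loss",  # required because the scheduler is a `ReduceLROnPlateau`
        },
    }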
def _set_scheduler_opt_idx(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
    for config in lr_scheduler_configs:

        for opt_idx, opt in enumerate(optimizers):
            if config.scheduler.optimizer is opt:
                if config.opt_idx is not None and config.opt_idx != opt_idx:
                    raise MisconfigurationException(
                        "`opt_idx` set inside scheduler config does not match with the index"
                        " of the respective optimizer returned from `configure_optimizers`."
                    )

                config.opt_idx = opt_idx
                break
        else:
            raise MisconfigurationException(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
    valid_keys = {"optimizer", "lr_scheduler", "frequency", "monitor"}
    extra_keys = optim_conf.keys() - valid_keys
    if extra_keys:
        rank_zero_warn(
            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
        )


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None` is returned from
    `configure_optimizers`."""

    def __init__(self) -> None:
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def state_dict(self) -> Dict[str, Any]:
        return {}  # Return Empty

    def step(self, closure: Optional[Callable] = None) -> None:
        if closure is not None:
            closure()

    def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
        pass  # Do Nothing

    def __repr__(self) -> str:
        return "No Optimizer"
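
# Illustrative sketch: the minimal surface a custom scheduler needs to satisfy the validation
# above, namely `state_dict`/`load_state_dict` (checked via `_Stateful`) and an `optimizer`
# attribute so `_set_scheduler_opt_idx` can match it to an optimizer. Because it is not a
# standard PyTorch scheduler, `LightningModule.lr_scheduler_step` would also have to be
# overridden to call it. `_ExampleConstantScheduler` is a hypothetical name, for illustration only.
class _ExampleConstantScheduler:
    def __init__(self, optimizer: Optimizer) -> None:
        self.optimizer = optimizer  # inspected by `_set_scheduler_opt_idx`

    def step(self) -> None:
        pass  # keep the learning rate unchanged

    def state_dict(self) -> Dict[str, Any]:
        return {}  # nothing to checkpoint

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        pass  # nothing to restore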