# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from dataclasses import fields
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
from weakref import proxy

import torch
from torch import optim
from torch.optim import Optimizer

import lightning.pytorch as pl
from lightning.fabric.utilities.types import _Stateful, Optimizable, ReduceLROnPlateau
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.model_helpers import is_overridden
from lightning.pytorch.utilities.rank_zero import rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig, LRSchedulerTypeTuple


def do_nothing_closure() -> None:
    return

class LightningOptimizer:
    """This class wraps the user's optimizers and properly handles the backward and optimizer_step logic across
    accelerators, AMP, and accumulate_grad_batches.

    Note: The purpose of this wrapper is only to define new methods and redirect the ``.step()`` call. The internal
    state ``__dict__`` is not kept in sync with the internal state of the original optimizer, but the Trainer never
    relies on the internal state of the wrapper.

    """

    def __init__(self, optimizer: Optimizer):
        # make the wrapper a dynamic subclass of the wrapped optimizer's class so that
        # `isinstance(wrapper, type(optimizer))` checks keep working
        self.__class__ = type("Lightning" + optimizer.__class__.__name__, (self.__class__, optimizer.__class__), {})

        self._optimizer = optimizer
        self._strategy: Optional[pl.strategies.Strategy] = None
        # to inject logic around the optimizer step, particularly useful with manual optimization
        self._on_before_step = do_nothing_closure
        self._on_after_step = do_nothing_closure

        self.refresh()

    @property
    def optimizer(self) -> Optimizer:
        return self._optimizer
    @contextmanager
    def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
        """This function is just a helper for advanced users.

        Consider the current optimizer as A and all other optimizers as B. Toggling means that all parameters from B
        exclusive to A will have ``requires_grad`` set to ``False``.

        When performing gradient accumulation, there is no need to perform grad synchronization during the
        accumulation phase. Setting ``sync_grad`` to ``False`` will block this synchronization and improve
        performance.

        """
        # local import here to avoid circular import
        from lightning.pytorch.loops.utilities import _block_parallel_sync_behavior

        assert self._strategy is not None
        lightning_module = self._strategy.lightning_module
        assert lightning_module is not None
        with _block_parallel_sync_behavior(self._strategy, block=(not sync_grad)):
            lightning_module.toggle_optimizer(self)
            yield
            lightning_module.untoggle_optimizer(self)
    def refresh(self) -> None:
        """Refreshes the ``__dict__`` so that it matches the internal state of the wrapped optimizer.

        This is only needed to present the user with an updated view in case they inspect the state of this wrapper.

        """
        # copy most of the `Optimizer` attributes into this instance. `step` and `__del__` are skipped in case the
        # optimizer has implemented custom logic which we would not want to call on destruction of the
        # `LightningOptimizer`
        self.__dict__.update({k: v for k, v in self.optimizer.__dict__.items() if k not in ("step", "__del__")})
    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        """Performs a single optimization step (parameter update).

        Args:
            closure: An optional optimizer closure.
            kwargs: Any additional arguments to the ``optimizer.step()`` call.

        Returns:
            The output from the step call, which is generally the output of the closure execution.

        Example::

            # Scenario for a GAN using manual optimization
            def training_step(self, batch, batch_idx):
                opt_gen, opt_dis = self.optimizers()

                ...

                # compute generator loss
                loss_gen = self.compute_generator_loss(...)
                # zero_grad needs to be called before backward
                opt_gen.zero_grad()
                self.manual_backward(loss_gen)
                opt_gen.step()

                # compute discriminator loss
                loss_dis = self.compute_discriminator_loss(...)

                # zero_grad needs to be called before backward
                opt_dis.zero_grad()
                self.manual_backward(loss_dis)
                opt_dis.step()

            # A more advanced example
            def training_step(self, batch, batch_idx):
                opt_gen, opt_dis = self.optimizers()

                ...
                accumulated_grad_batches = batch_idx % 2 == 0

                # compute generator loss
                def closure_gen():
                    loss_gen = self.compute_generator_loss(...)
                    self.manual_backward(loss_gen)
                    if accumulated_grad_batches:
                        opt_gen.zero_grad()

                with opt_gen.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_gen.step(closure=closure_gen)

                def closure_dis():
                    loss_dis = self.compute_discriminator_loss(...)
                    self.manual_backward(loss_dis)
                    if accumulated_grad_batches:
                        opt_dis.zero_grad()

                with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                    opt_dis.step(closure=closure_dis)

        """
        self._on_before_step()

        if closure is None:
            closure = do_nothing_closure
        elif not callable(closure):
            raise MisconfigurationException("When `optimizer.step(closure)` is called, the closure should be callable")

        assert self._strategy is not None
        step_output = self._strategy.optimizer_step(self._optimizer, closure, **kwargs)

        self._on_after_step()

        return step_output
    @classmethod
    def _to_lightning_optimizer(
        cls, optimizer: Union[Optimizer, "LightningOptimizer"], strategy: "pl.strategies.Strategy"
    ) -> "LightningOptimizer":
        # the user could return a `LightningOptimizer` from `configure_optimizers`, see test:
        # tests/core/test_lightning_optimizer.py::test_lightning_optimizer[False]
        lightning_optimizer = optimizer if isinstance(optimizer, LightningOptimizer) else cls(optimizer)
        lightning_optimizer._strategy = proxy(strategy)
        return lightning_optimizer
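
# Illustrative sketch (not part of the original module): the dynamic subclassing performed in
# `LightningOptimizer.__init__` makes the wrapper an instance of both `LightningOptimizer` and the wrapped
# optimizer's class, while `refresh()` mirrors the wrapped optimizer's `__dict__`. The helper below is a
# hypothetical demonstration of that behavior, assuming only `torch` and this module.
def _example_wrap_optimizer() -> None:  # pragma: no cover - illustrative sketch only
    params = [torch.randn(2, 2, requires_grad=True)]
    sgd = optim.SGD(params, lr=0.1)
    wrapped = LightningOptimizer(sgd)

    assert isinstance(wrapped, LightningOptimizer)
    assert isinstance(wrapped, optim.SGD)  # dynamic subclass created in `__init__`
    assert wrapped.optimizer is sgd
    assert wrapped.param_groups is sgd.param_groups  # attributes copied over by `refresh()`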

def _init_optimizers_and_lr_schedulers(
    model: "pl.LightningModule",
) -> Tuple[List[Optimizer], List[LRSchedulerConfig]]:
    """Calls `LightningModule.configure_optimizers` and parses and validates the output."""
    from lightning.pytorch.trainer import call

    optim_conf = call._call_lightning_module_hook(model.trainer, "configure_optimizers", pl_module=model)

    if optim_conf is None:
        rank_zero_warn(
            "`LightningModule.configure_optimizers` returned `None`, this fit will run with no optimizer",
        )
        optim_conf = _MockOptimizer()

    optimizers, lr_schedulers, monitor = _configure_optimizers(optim_conf)
    lr_scheduler_configs = (
        _configure_schedulers_automatic_opt(lr_schedulers, monitor)
        if model.automatic_optimization
        else _configure_schedulers_manual_opt(lr_schedulers)
    )
    _validate_multiple_optimizers_support(optimizers, model)
    _validate_optimizers_attached(optimizers, lr_scheduler_configs)
    _validate_scheduler_api(lr_scheduler_configs, model)
    return optimizers, lr_scheduler_configs


def _configure_optimizers(
    optim_conf: Union[Dict[str, Any], List, Optimizer, Tuple],
) -> Tuple[List, List, Optional[str]]:
    optimizers, lr_schedulers = [], []
    monitor = None

    # single output, single optimizer
    if isinstance(optim_conf, Optimizable):
        optimizers = [optim_conf]
    # two lists, optimizers + lr schedulers
    elif (
        isinstance(optim_conf, (list, tuple))
        and len(optim_conf) == 2
        and isinstance(optim_conf[0], list)
        and all(isinstance(opt, Optimizable) for opt in optim_conf[0])
    ):
        opt, sch = optim_conf
        optimizers = opt
        lr_schedulers = sch if isinstance(sch, list) else [sch]
    # single dictionary
    elif isinstance(optim_conf, dict):
        _validate_optim_conf(optim_conf)
        optimizers = [optim_conf["optimizer"]]
        monitor = optim_conf.get("monitor", None)
        lr_schedulers = [optim_conf["lr_scheduler"]] if "lr_scheduler" in optim_conf else []
    # multiple dictionaries
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
        for opt_dict in optim_conf:
            _validate_optim_conf(opt_dict)
        optimizers = [opt_dict["optimizer"] for opt_dict in optim_conf]
        scheduler_dict = lambda scheduler: dict(scheduler) if isinstance(scheduler, dict) else {"scheduler": scheduler}
        lr_schedulers = [
            scheduler_dict(opt_dict["lr_scheduler"]) for opt_dict in optim_conf if "lr_scheduler" in opt_dict
        ]
    # single list or tuple, multiple optimizers
    elif isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizable) for opt in optim_conf):
        optimizers = list(optim_conf)
    # unknown configuration
    else:
        raise MisconfigurationException(
            "Unknown configuration for model optimizers."
            " Output from `model.configure_optimizers()` should be one of:\n"
            " * `Optimizer`\n"
            " * [`Optimizer`]\n"
            " * ([`Optimizer`], [`LRScheduler`])\n"
            ' * {"optimizer": `Optimizer`, (optional) "lr_scheduler": `LRScheduler`}\n'
        )
    return optimizers, lr_schedulers, monitor


def _configure_schedulers_automatic_opt(schedulers: list, monitor: Optional[str]) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` with relevant information, when using automatic
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # check provided keys
            supported_keys = {field.name for field in fields(LRSchedulerConfig)}
            extra_keys = scheduler.keys() - supported_keys
            if extra_keys:
                rank_zero_warn(
                    f"Found unsupported keys in the lr scheduler dict: {extra_keys}."
                    " HINT: remove them from the output of `configure_optimizers`.",
                    category=RuntimeWarning,
                )
                scheduler = {k: v for k, v in scheduler.items() if k in supported_keys}
            if "scheduler" not in scheduler:
                raise MisconfigurationException(
                    'The lr scheduler dict must have the key "scheduler" with its item being an lr scheduler'
                )
            if "interval" in scheduler and scheduler["interval"] not in ("step", "epoch"):
                raise MisconfigurationException(
                    'The "interval" key in lr scheduler dict must be "step" or "epoch"'
                    f' but is "{scheduler["interval"]}"'
                )
            scheduler["reduce_on_plateau"] = scheduler.get(
                "reduce_on_plateau", isinstance(scheduler["scheduler"], optim.lr_scheduler.ReduceLROnPlateau)
            )
            if scheduler["reduce_on_plateau"] and scheduler.get("monitor", None) is None:
                raise MisconfigurationException(
                    "The lr scheduler dict must include a monitor when a `ReduceLROnPlateau` scheduler is used."
                    ' For example: {"optimizer": optimizer, "lr_scheduler":'
                    ' {"scheduler": scheduler, "monitor": "your_loss"}}'
                )
            is_one_cycle = isinstance(scheduler["scheduler"], optim.lr_scheduler.OneCycleLR)
            if is_one_cycle and scheduler.get("interval", "epoch") == "epoch":
                rank_zero_warn(
                    "A `OneCycleLR` scheduler is using 'interval': 'epoch'."
                    " Are you sure you didn't mean 'interval': 'step'?",
                    category=RuntimeWarning,
                )
            config = LRSchedulerConfig(**scheduler)
        elif isinstance(scheduler, ReduceLROnPlateau):
            if monitor is None:
                raise MisconfigurationException(
                    "`configure_optimizers` must include a monitor when a `ReduceLROnPlateau`"
                    " scheduler is used. For example:"
                    ' {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}'
                )
            config = LRSchedulerConfig(scheduler, reduce_on_plateau=True, monitor=monitor)
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _configure_schedulers_manual_opt(schedulers: list) -> List[LRSchedulerConfig]:
    """Convert each scheduler into `LRSchedulerConfig` structure with relevant information, when using manual
    optimization."""
    lr_scheduler_configs = []
    for scheduler in schedulers:
        if isinstance(scheduler, dict):
            # interval is not in this list even though the user needs to manually call the scheduler because
            # the `LearningRateMonitor` callback needs to check its value to know when to log the learning rate
            invalid_keys = {"reduce_on_plateau", "monitor", "strict"}
            keys_to_warn = [k for k in scheduler if k in invalid_keys]

            if keys_to_warn:
                rank_zero_warn(
                    f"The lr scheduler dict contains the key(s) {keys_to_warn}, but the keys will be ignored."
                    " You need to call `lr_scheduler.step()` manually in manual optimization.",
                    category=RuntimeWarning,
                )

            config = LRSchedulerConfig(**{key: scheduler[key] for key in scheduler if key not in invalid_keys})
        else:
            config = LRSchedulerConfig(scheduler)
        lr_scheduler_configs.append(config)
    return lr_scheduler_configs


def _validate_scheduler_api(lr_scheduler_configs: List[LRSchedulerConfig], model: "pl.LightningModule") -> None:
    for config in lr_scheduler_configs:
        scheduler = config.scheduler
        if not isinstance(scheduler, _Stateful):
            raise TypeError(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` is invalid."
                " It should have `state_dict` and `load_state_dict` methods defined."
            )

        if (
            not isinstance(scheduler, LRSchedulerTypeTuple)
            and not is_overridden("lr_scheduler_step", model)
            and model.automatic_optimization
        ):
            raise MisconfigurationException(
                f"The provided lr scheduler `{scheduler.__class__.__name__}` doesn't follow PyTorch's LRScheduler"
                " API. You should override the `LightningModule.lr_scheduler_step` hook with your own logic if"
                " you are using a custom LR scheduler."
            )


def _validate_multiple_optimizers_support(optimizers: List[Optimizer], model: "pl.LightningModule") -> None:
    if model.automatic_optimization and len(optimizers) > 1:
        raise RuntimeError(
            "Training with multiple optimizers is only supported with manual optimization. Set"
            " `self.automatic_optimization = False`, then access your optimizers in `training_step` with"
            " `opt1, opt2, ... = self.optimizers()`."
        )

def _validate_optimizers_attached(optimizers: List[Optimizer], lr_scheduler_configs: List[LRSchedulerConfig]) -> None:
    for config in lr_scheduler_configs:
        if config.scheduler.optimizer not in optimizers:
            raise MisconfigurationException(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )


def _validate_optim_conf(optim_conf: Dict[str, Any]) -> None:
    valid_keys = {"optimizer", "lr_scheduler", "monitor"}
    extra_keys = optim_conf.keys() - valid_keys
    if extra_keys:
        rank_zero_warn(
            f"Found unsupported keys in the optimizer configuration: {set(extra_keys)}", category=RuntimeWarning
        )


class _MockOptimizer(Optimizer):
    """The `_MockOptimizer` will be used in place of an optimizer in the event that `None` is returned from
    `configure_optimizers`."""

    def __init__(self) -> None:
        super().__init__([torch.zeros(1)], {})

    def add_param_group(self, param_group: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def load_state_dict(self, state_dict: Dict[Any, Any]) -> None:
        pass  # Do Nothing

    def state_dict(self) -> Dict[str, Any]:
        return {}  # Return Empty

    def step(self, closure: Optional[Callable] = None) -> None:
        if closure is not None:
            closure()

    def zero_grad(self, set_to_none: Optional[bool] = False) -> None:
        pass  # Do Nothing

    def __repr__(self) -> str:
        return "No Optimizer"
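
# Illustrative sketch (not part of the original module): `_configure_optimizers` above accepts the return value of
# `LightningModule.configure_optimizers` in several shapes. The hypothetical helper below, assuming only `torch`
# optimizers and schedulers, shows how each shape is parsed into `(optimizers, lr_schedulers, monitor)`.
def _example_configure_optimizers_shapes() -> None:  # pragma: no cover - illustrative sketch only
    model = torch.nn.Linear(2, 2)
    opt = optim.Adam(model.parameters(), lr=1e-3)
    sch = optim.lr_scheduler.StepLR(opt, step_size=10)

    # single optimizer
    assert _configure_optimizers(opt) == ([opt], [], None)
    # single list (or tuple) of optimizers
    assert _configure_optimizers([opt]) == ([opt], [], None)
    # two lists: optimizers and schedulers
    assert _configure_optimizers(([opt], [sch])) == ([opt], [sch], None)
    # single dictionary, optionally with a "monitor" key
    assert _configure_optimizers({"optimizer": opt, "lr_scheduler": sch}) == ([opt], [sch], None)
    # list of dictionaries; scheduler entries are normalized to dicts with a "scheduler" key
    assert _configure_optimizers([{"optimizer": opt, "lr_scheduler": sch}]) == ([opt], [{"scheduler": sch}], None)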