Source code for pytorch_lightning.loops.optimization.optimizer_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Union

import torch
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
from pytorch_lightning.loops.utilities import (
    _block_parallel_sync_behavior,
    _build_training_step_kwargs,
    _extract_hiddens,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache


@dataclass
class ClosureResult(OutputResult):
    """A container to hold the result of a :class:`Closure` call.

    It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.

    Attributes:
        closure_loss: The loss with a graph attached.
        loss: A detached copy of the closure loss.
        extra: Any keys other than the loss returned.
    """

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self._clone_loss()

    def _clone_loss(self) -> None:
        if self.closure_loss is not None:
            # the loss will get scaled for amp. avoid any modifications to it
            self.loss = self.closure_loss.detach().clone()

    @classmethod
    def from_training_step_output(
        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
    ) -> "ClosureResult":
        closure_loss, extra = None, {}

        if isinstance(training_step_output, dict):
            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
            closure_loss = training_step_output.get("loss")
            if closure_loss is None:
                raise MisconfigurationException(
                    "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
                )
            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
        elif isinstance(training_step_output, Tensor):
            closure_loss = training_step_output
        elif training_step_output is not None:
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must return a Tensor, "
                "a dict, or None (where the step will be skipped)."
            )

        if closure_loss is not None:
            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
            # note: avoid in-place operation `x /= y` here on purpose
            closure_loss = closure_loss / normalize

        return cls(closure_loss, extra=extra)

    def asdict(self) -> Dict[str, Any]:
        return {"loss": self.loss, **self.extra}


class Closure(AbstractClosure[ClosureResult]):
    """An implementation of a :class:`AbstractClosure` for automatic optimization in Lightning that combines three
    elementary closures into one: ``training_step``, ``backward`` and ``zero_grad``.

    The Closure gets created by the training loop(s) and is then passed to the :meth:`torch.optim.Optimizer.step`
    method. The optimizer is responsible for calling the closure and optionally doing something with the output.

    Args:
        step_fn: This is typically the :meth:`pytorch_lightning.core.module.LightningModule.training_step` wrapped
            with processing for its outputs.
        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss value.
            Can be set to ``None`` to skip the backward operation.
        zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
            when accumulating gradients.

    Example:

        closure = Closure()
        optimizer = torch.optim.Adam(...)
        optimizer.step(closure)
    """

    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], ClosureResult],
        backward_fn: Optional[Callable[[Tensor], None]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn

    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
        step_output = self._step_fn()

        if step_output.closure_loss is None:
            self.warning_cache.warn("`training_step` returned `None`. If this was on purpose, ignore this warning...")

        if self._zero_grad_fn is not None:
            self._zero_grad_fn()

        if self._backward_fn is not None and step_output.closure_loss is not None:
            self._backward_fn(step_output.closure_loss)

        return step_output

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        self._result = self.closure(*args, **kwargs)
        return self._result.loss


_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
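
# Illustrative sketch (plain PyTorch, not part of this module) of the closure pattern that ``Closure``
# implements above: ``torch.optim.Optimizer.step`` accepts a callable that re-evaluates the loss, zeroes
# the gradients and runs backward, and the optimizer decides when and how often to call it (e.g. LBFGS
# calls it repeatedly). ``model``, ``x`` and ``y`` are placeholders, not Lightning objects.
#
#     import torch
#
#     model = torch.nn.Linear(4, 1)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     x, y = torch.randn(8, 4), torch.randn(8, 1)
#
#     def closure():
#         optimizer.zero_grad()                             # role of ``zero_grad_fn``
#         loss = torch.nn.functional.mse_loss(model(x), y)  # role of ``step_fn``
#         loss.backward()                                   # role of ``backward_fn``
#         return loss
#
#     optimizer.step(closure)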

class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a sequence of optimizers.

    This loop implements what is known in Lightning as Automatic Optimization.
    """

    output_result_cls = ClosureResult

    def __init__(self) -> None:
        super().__init__()
        self.optim_progress: OptimizationProgress = OptimizationProgress()

        self._outputs: _OUTPUTS_TYPE = {}
        self._skip_backward: bool = False
        self._optimizers: Tuple[Optimizer, ...] = tuple()
        self._indices: Tuple[int, ...] = tuple()
        self._hiddens: Optional[Any] = None

    @property
    def optimizer_idx(self) -> int:
        return self._indices[self.optim_progress.optimizer_position]

    @property
    def done(self) -> bool:
        """Returns ``True`` when the last optimizer in the sequence has run."""
        return self.optim_progress.optimizer_position >= len(self._indices)
    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
    def reset(self) -> None:
        if not self.restarting:
            # when reset() is called from outside (manually), we reset the loop progress
            self.optim_progress.optimizer_position = 0
        else:
            self.optim_progress.reset_on_restart()

        self._outputs = {}
    def advance(self, optimizers: List[Tuple[int, Optimizer]], kwargs: OrderedDict) -> None:  # type: ignore[override]
        kwargs = self._build_kwargs(kwargs, self.optimizer_idx, self._hiddens)

        result = self._run_optimization(kwargs, self._optimizers[self.optim_progress.optimizer_position])
        if result.loss is not None:
            # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
            # would be skipped otherwise
            self._outputs[self.optimizer_idx] = result.asdict()
        self.optim_progress.optimizer_position += 1
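
    # Illustrative sketch (plain PyTorch, not part of this class) of the gradient-accumulation behaviour that
    # ``_run_optimization``, ``_make_zero_grad_fn`` and ``_optimizer_step`` below implement: zero the gradients
    # only on the first batch of an accumulation window, normalize the loss by the window size, and step only
    # once the window is full. ``model``, ``loader``, ``loss_fn`` and ``optimizer`` are placeholders.
    #
    #     accumulate_grad_batches = 4
    #     for batch_idx, (x, y) in enumerate(loader):
    #         if batch_idx % accumulate_grad_batches == 0:
    #             optimizer.zero_grad()
    #         loss = loss_fn(model(x), y) / accumulate_grad_batches
    #         loss.backward()
    #         if (batch_idx + 1) % accumulate_grad_batches == 0:
    #             optimizer.step()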
    def _run_optimization(self, kwargs: OrderedDict, optimizer: torch.optim.Optimizer) -> ClosureResult:
        """Runs the closure (train step + backward) together with optimization if necessary.

        Args:
            kwargs: the kwargs passed down to the hooks.
            optimizer: the current optimizer
        """
        opt_idx = kwargs.get("optimizer_idx", 0)

        # toggle model params
        self._run_optimization_start(opt_idx, optimizer)

        closure = self._make_closure(kwargs, optimizer)

        if (
            # when the strategy handles accumulation, we want to always call the optimizer step
            not self.trainer.strategy.handles_gradient_accumulation
            and self.trainer.fit_loop._should_accumulate()
        ):
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            # the `batch_idx` is optional with inter-batch parallelism
            self._optimizer_step(optimizer, opt_idx, kwargs.get("batch_idx", 0), closure)

        result = closure.consume_result()

        if result.loss is not None:
            # if no result, user decided to skip optimization
            # otherwise update running loss + reset accumulated loss
            # TODO: find proper way to handle updating running loss
            self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)

        # untoggle model params
        self._run_optimization_end(opt_idx)
        return result

    def _make_closure(self, kwargs: OrderedDict, optimizer: Optimizer) -> Closure:
        """Build a closure object that captures the given arguments and runs the `training_step` function and
        optionally other functions such as `backward` and `zero_grad`."""
        opt_idx = kwargs.get("optimizer_idx", 0)
        step_fn = self._make_step_fn(kwargs)
        backward_fn = self._make_backward_fn(optimizer, opt_idx)
        zero_grad_fn = self._make_zero_grad_fn(kwargs.get("batch_idx", 0), opt_idx, optimizer)
        return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

    def _make_step_fn(self, kwargs: OrderedDict) -> Callable[[], ClosureResult]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, kwargs)

    def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
        """Build a `zero_grad` function that zeroes the gradients before back-propagation.

        Returns ``None`` if backward needs to be skipped.
        """
        if self._skip_backward:
            return None

        is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
        if not is_first_batch_to_accumulate:
            return None

        def zero_grad_fn() -> None:
            self._on_before_zero_grad(optimizer)
            self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        return zero_grad_fn

    def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], None]]:
        """Build a `backward` function that handles back-propagation through the output produced by the
        `training_step` function.

        Returns ``None`` if backward needs to be skipped.
        """
        if self._skip_backward:
            return None

        def backward_fn(loss: Tensor) -> None:
            self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)

        return backward_fn

    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
        """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.
        Args:
            opt_idx: the index of the optimizer to use
            optimizer: the optimizer to use
        """
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in a multiple-optimizer setup.
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

    def _run_optimization_end(self, opt_idx: int) -> None:
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.untoggle_optimizer(opt_idx)

    def _optimizer_step(
        self,
        optimizer: Union[Optimizer, LightningOptimizer],
        opt_idx: int,
        batch_idx: int,
        train_step_and_backward_closure: Callable[[], Optional[Tensor]],
    ) -> None:
        """Performs the optimizer step and some sanity checking.

        Args:
            optimizer: the optimizer to perform the step with
            opt_idx: the index of the current :param:`optimizer`
            batch_idx: the index of the current batch
            train_step_and_backward_closure: the closure function performing the train step and computing the
                gradients. By default, called by the optimizer (if possible)
        """
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

        # wraps into LightningOptimizer only for running step
        if self.trainer.amp_backend == AMPType.APEX:
            # apex overrides the .step function and needs to be wrapped on each step
            optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer.strategy, opt_idx)
        else:
            optimizer = self.trainer.strategy._lightning_optimizers[opt_idx]

        # if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we
        # need to check again if `should_accumulate` before increasing the counters
        should_accumulate = self.trainer.fit_loop._should_accumulate()
        if not should_accumulate:
            self.optim_progress.optimizer.step.increment_ready()

        # model hook
        self.trainer._call_lightning_module_hook(
            "optimizer_step",
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
            using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
            using_lbfgs=is_lbfgs,
        )

        if not should_accumulate:
            self.optim_progress.optimizer.step.increment_completed()

    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Calls the ``on_before_zero_grad`` hook.

        Args:
            optimizer: the current optimizer
        """
        self.optim_progress.optimizer.zero_grad.increment_ready()
        self.trainer._call_callback_hooks("on_before_zero_grad", optimizer)
        self.trainer._call_lightning_module_hook("on_before_zero_grad", optimizer)
        self.optim_progress.optimizer.zero_grad.increment_started()

    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
        """Zeroes out all gradients of parameters optimized by the current optimizer.

        Args:
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        self.trainer._call_lightning_module_hook(
            "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
        )
        self.optim_progress.optimizer.zero_grad.increment_completed()

    def _training_step(self, kwargs: OrderedDict) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            kwargs: the kwargs passed down to the hooks.

        Returns:
            A ``ClosureResult`` containing the training step output.
"""# manually capture logged metricstraining_step_output=self.trainer._call_strategy_hook("training_step",*kwargs.values())self.trainer.strategy.post_training_step()model_output=self.trainer._call_lightning_module_hook("training_step_end",training_step_output)strategy_output=self.trainer._call_strategy_hook("training_step_end",training_step_output)training_step_output=strategy_outputifmodel_outputisNoneelsemodel_outputself._hiddens=_extract_hiddens(training_step_output,self.trainer.lightning_module.truncated_bptt_steps)result=self.output_result_cls.from_training_step_output(training_step_output,self.trainer.accumulate_grad_batches)ifself.trainer.move_metrics_to_cpu:# hiddens and the training step output are not moved as they are not considered "metrics"assertself.trainer._resultsisnotNoneself.trainer._results.cpu()returnresultdef_build_kwargs(self,kwargs:OrderedDict,opt_idx:int,hiddens:Optional[Any])->OrderedDict:"""Helper method to build the arguments for the current step. Args: kwargs: The kwargs passed down to the hooks. opt_idx: the index of the current optimizer. hiddens: the hidden state of the previous RNN iteration. Returns: The kwargs passed down to the hooks. """return_build_training_step_kwargs(kwargs,self.trainer.lightning_module,self.trainer.optimizers,opt_idx,hiddens)