Source code for pytorch_lightning.loops.optimization.optimizer_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.optim import Optimizer

from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.core.optimizer import LightningOptimizer
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import AbstractClosure, OutputResult
from pytorch_lightning.loops.utilities import (
    _block_parallel_sync_behavior,
    _build_training_step_kwargs,
    _extract_hiddens,
    check_finite_loss,
)
from pytorch_lightning.trainer.progress import OptimizationProgress
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.finite_checks import detect_nan_parameters
from pytorch_lightning.utilities.types import STEP_OUTPUT
from pytorch_lightning.utilities.warnings import WarningCache


@dataclass
class ClosureResult(OutputResult):
    """A container to hold the result of a :class:`Closure` call.

    It is created from the output of :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.

    Attributes:
        closure_loss: The loss with a graph attached.
        loss: A detached copy of the closure loss.
        extra: Any keys other than the loss returned.
    """

    closure_loss: Optional[Tensor]
    loss: Optional[Tensor] = field(init=False, default=None)
    extra: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        self._clone_loss()

    def _clone_loss(self) -> None:
        if self.closure_loss is not None:
            # the loss will get scaled for amp. avoid any modifications to it
            self.loss = self.closure_loss.detach().clone()

    @classmethod
    def from_training_step_output(
        cls, training_step_output: Optional[STEP_OUTPUT], normalize: int = 1
    ) -> "ClosureResult":
        closure_loss, extra = None, {}

        if isinstance(training_step_output, dict):
            # this should not modify the `training_step_output`, as the user could be using it after `training_step_end`
            closure_loss = training_step_output.get("loss")
            if closure_loss is None:
                raise MisconfigurationException(
                    "In automatic_optimization, when `training_step` returns a dict, the 'loss' key needs to be present"
                )
            extra = {k: v for k, v in training_step_output.items() if k not in ("loss", "hiddens")}
        elif isinstance(training_step_output, Tensor):
            closure_loss = training_step_output
        elif training_step_output is not None:
            raise MisconfigurationException(
                "In automatic optimization, `training_step` must return a Tensor, "
                "a dict, or None (where the step will be skipped)."
            )

        if closure_loss is not None:
            # accumulate the loss. If ``accumulate_grad_batches == 1``, no effect
            # note: avoid in-place operation `x /= y` here on purpose
            closure_loss = closure_loss / normalize

        return cls(closure_loss, extra=extra)

    def asdict(self) -> Dict[str, Any]:
        return {"loss": self.loss, **self.extra}


class Closure(AbstractClosure[ClosureResult]):
    """An implementation of a :class:`AbstractClosure` for automatic optimization in Lightning that combines three
    elementary closures into one: ``training_step``, ``backward`` and ``zero_grad``.

    The Closure gets created by the training loop(s) and is then passed to the :meth:`torch.optim.Optimizer.step`
    method. An optimizer is responsible for calling the closure and optionally doing something with the output.

    Args:
        step_fn: This is typically the :meth:`pytorch_lightning.core.lightning.LightningModule.training_step`
            wrapped with processing for its outputs
        backward_fn: A function that takes a loss value as input, performs back-propagation and returns the loss
            value. Can be set to ``None`` to skip the backward operation.
        zero_grad_fn: A function that zeroes the gradients. Can be set to ``None`` to skip zero_grad, for example
            when accumulating gradients.

    Example:

        closure = Closure()
        optimizer = torch.optim.Adam(...)
        optimizer.step(closure)
    """

    warning_cache = WarningCache()

    def __init__(
        self,
        step_fn: Callable[[], ClosureResult],
        backward_fn: Optional[Callable[[Tensor], None]] = None,
        zero_grad_fn: Optional[Callable[[], None]] = None,
    ):
        super().__init__()
        self._step_fn = step_fn
        self._backward_fn = backward_fn
        self._zero_grad_fn = zero_grad_fn

    def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
        step_output = self._step_fn()

        if step_output.closure_loss is None:
            self.warning_cache.warn(
                "`training_step` returned `None`. If this was on purpose, ignore this warning..."
            )

        if self._zero_grad_fn is not None:
            self._zero_grad_fn()

        if self._backward_fn is not None and step_output.closure_loss is not None:
            self._backward_fn(step_output.closure_loss)

        return step_output

    def __call__(self, *args: Any, **kwargs: Any) -> Optional[Tensor]:
        self._result = self.closure(*args, **kwargs)
        return self._result.loss


_OUTPUTS_TYPE = Dict[int, Dict[str, Any]]
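For context, here is a minimal sketch (not part of this module) of the plain-PyTorch closure pattern that ``Closure`` mirrors: :meth:`torch.optim.Optimizer.step` accepts a callable that re-evaluates the model and returns the loss, and optimizers such as LBFGS may call it several times per step. The model, data, and loss function below are made-up placeholders for illustration only.

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(4, 1)                        # placeholder model
    optimizer = torch.optim.LBFGS(model.parameters())    # an optimizer that requires a closure
    batch, target = torch.randn(8, 4), torch.randn(8, 1)

    def closure() -> torch.Tensor:
        optimizer.zero_grad()                    # roughly what the generated `zero_grad_fn` does
        loss = F.mse_loss(model(batch), target)  # roughly the `step_fn` (training_step + output processing)
        loss.backward()                          # roughly the `backward_fn`
        return loss

    optimizer.step(closure)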
class OptimizerLoop(Loop[_OUTPUTS_TYPE]):
    """Runs over a sequence of optimizers.

    This loop implements what is known in Lightning as Automatic Optimization.
    """

    output_result_cls = ClosureResult

    def __init__(self) -> None:
        super().__init__()
        self.optim_progress: OptimizationProgress = OptimizationProgress()

        self._outputs: _OUTPUTS_TYPE = {}
        self._skip_backward: bool = False
        self._batch_idx: int = 0
        self._optimizers: Tuple[Optimizer, ...] = tuple()
        self._indices: Tuple[int, ...] = tuple()
        self._hiddens: Optional[Any] = None

    @property
    def optimizer_idx(self) -> int:
        return self._indices[self.optim_progress.optimizer_position]

    @property
    def done(self) -> bool:
        """Returns ``True`` when the last optimizer in the sequence has run."""
        return self.optim_progress.optimizer_position >= len(self._indices)
    def connect(self, **kwargs: "Loop") -> None:
        raise NotImplementedError(f"{self.__class__.__name__} does not connect any child loops.")
    def reset(self) -> None:
        if not self.restarting:
            # when reset() is called from outside (manually), we reset the loop progress
            self.optim_progress.optimizer_position = 0
        else:
            self.optim_progress.reset_on_restart()
        self._outputs = {}
    def advance(self, batch: Any, *args: Any, **kwargs: Any) -> None:  # type: ignore[override]
        result = self._run_optimization(
            batch,
            self._batch_idx,
            self._optimizers[self.optim_progress.optimizer_position],
            self.optimizer_idx,
        )
        if result.loss is not None:
            # automatic optimization assumes a loss needs to be returned for extras to be considered as the batch
            # would be skipped otherwise
            self._outputs[self.optimizer_idx] = result.asdict()

        self.optim_progress.optimizer_position += 1
    def _run_optimization(
        self, split_batch: Any, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int
    ) -> ClosureResult:
        """Runs closure (train step + backward) together with optimization if necessary.

        Args:
            split_batch: the current tbptt split of the whole batch
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        # toggle model params
        self._run_optimization_start(opt_idx, optimizer)

        closure = self._make_closure(split_batch, batch_idx, opt_idx, optimizer)

        if (
            # when the strategy handles accumulation, we want to always call the optimizer step
            not self.trainer.strategy.handles_gradient_accumulation
            and self.trainer.fit_loop._should_accumulate()
        ):
            # For gradient accumulation

            # -------------------
            # calculate loss (train step + train step end)
            # -------------------
            # automatic_optimization=True: perform ddp sync only when performing optimizer_step
            with _block_parallel_sync_behavior(self.trainer.strategy, block=True):
                closure()

        # ------------------------------
        # BACKWARD PASS
        # ------------------------------
        # gradient update with accumulated gradients
        else:
            self._optimizer_step(optimizer, opt_idx, batch_idx, closure)

        result = closure.consume_result()

        if result.loss is not None:
            # if no result, user decided to skip optimization
            # otherwise update running loss + reset accumulated loss
            # TODO: find proper way to handle updating running loss
            self.trainer.fit_loop.epoch_loop.batch_loop._update_running_loss(result.loss)

        # untoggle model params
        self._run_optimization_end(opt_idx)
        return result

    def _make_closure(self, split_batch: Any, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Closure:
        """Build a closure object that captures the given arguments and runs the `training_step` function and
        optionally other functions such as `backward` and `zero_grad`."""
        step_fn = self._make_step_fn(split_batch, batch_idx, opt_idx)
        backward_fn = self._make_backward_fn(optimizer, opt_idx)
        zero_grad_fn = self._make_zero_grad_fn(batch_idx, opt_idx, optimizer)

        return Closure(step_fn=step_fn, backward_fn=backward_fn, zero_grad_fn=zero_grad_fn)

    def _make_step_fn(self, split_batch: Any, batch_idx: int, opt_idx: int) -> Callable[[], ClosureResult]:
        """Build the step function that runs the `training_step` and processes its output."""
        return partial(self._training_step, split_batch, batch_idx, opt_idx)

    def _make_zero_grad_fn(self, batch_idx: int, opt_idx: int, optimizer: Optimizer) -> Optional[Callable[[], None]]:
        """Build a `zero_grad` function that zeroes the gradients before back-propagation.

        Returns ``None`` in the case backward needs to be skipped.
        """
        if self._skip_backward:
            return None

        is_first_batch_to_accumulate = batch_idx % self.trainer.accumulate_grad_batches == 0
        if not is_first_batch_to_accumulate:
            return None

        def zero_grad_fn() -> None:
            self._on_before_zero_grad(optimizer)
            self._optimizer_zero_grad(batch_idx, optimizer, opt_idx)

        return zero_grad_fn

    def _make_backward_fn(self, optimizer: Optimizer, opt_idx: int) -> Optional[Callable[[Tensor], None]]:
        """Build a `backward` function that handles back-propagation through the output produced by the
        `training_step` function.

        Returns ``None`` in the case backward needs to be skipped.
        """
        if self._skip_backward:
            return None

        def backward_fn(loss: Tensor) -> None:
            self.trainer._call_strategy_hook("backward", loss, optimizer, opt_idx)

            # check if model weights are nan
            if self.trainer._terminate_on_nan:
                detect_nan_parameters(self.trainer.lightning_module)

        return backward_fn

    def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
        """Toggles the optimizer to ensure the correct one is used and prevent dangling grads.

        Args:
            opt_idx: the index of the optimizer to use
            optimizer: the optimizer to use
        """
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.toggle_optimizer(optimizer, opt_idx)

    def _run_optimization_end(self, opt_idx: int) -> None:
        if len(self.trainer.optimizers) > 1:
            model = self.trainer.lightning_module
            model.untoggle_optimizer(opt_idx)

    def _optimizer_step(
        self,
        optimizer: Union[Optimizer, LightningOptimizer],
        opt_idx: int,
        batch_idx: int,
        train_step_and_backward_closure: Callable[[], Optional[Tensor]],
    ) -> None:
        """Performs the optimizer step and some sanity checking.

        Args:
            optimizer: the optimizer to perform the step with
            opt_idx: the index of the current :param:`optimizer`
            batch_idx: the index of the current batch
            train_step_and_backward_closure: the closure function performing the train step and computing the
                gradients. By default, called by the optimizer (if possible)
        """
        is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

        # wraps into LightningOptimizer only for running step
        if self.trainer.amp_backend == AMPType.APEX:
            # apex overrides .step function and need to be wrapped on each step
            optimizer = LightningOptimizer._to_lightning_optimizer(optimizer, self.trainer.strategy, opt_idx)
        else:
            optimizer = self.trainer.strategy._lightning_optimizers[opt_idx]

        # if `strategy.handles_gradient_accumulation`, this method will be called to route into the strategy, but we
        # need to check again if `should_accumulate` before increasing the counters
        should_accumulate = self.trainer.fit_loop._should_accumulate()
        if not should_accumulate:
            self.optim_progress.optimizer.step.increment_ready()

        # model hook
        self.trainer._call_lightning_module_hook(
            "optimizer_step",
            self.trainer.current_epoch,
            batch_idx,
            optimizer,
            opt_idx,
            train_step_and_backward_closure,
            on_tpu=isinstance(self.trainer.accelerator, TPUAccelerator),
            using_native_amp=(self.trainer.amp_backend == AMPType.NATIVE),
            using_lbfgs=is_lbfgs,
        )

        if not should_accumulate:
            self.optim_progress.optimizer.step.increment_completed()

    def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
        """Calls the ``on_before_zero_grad`` hook.

        Args:
            optimizer: the current optimizer
        """
        self.optim_progress.optimizer.zero_grad.increment_ready()
        self.trainer._call_callback_hooks("on_before_zero_grad", optimizer)
        self.trainer._call_lightning_module_hook("on_before_zero_grad", optimizer)
        self.optim_progress.optimizer.zero_grad.increment_started()

    def _optimizer_zero_grad(self, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int) -> None:
        """Zeroes out all gradients of parameters optimized by the current optimizer.

        Args:
            batch_idx: the index of the current batch
            optimizer: the current optimizer
            opt_idx: the index of the current optimizer
        """
        self.trainer._call_lightning_module_hook(
            "optimizer_zero_grad", self.trainer.current_epoch, batch_idx, optimizer, opt_idx
        )
        self.optim_progress.optimizer.zero_grad.increment_completed()

    def _training_step(self, split_batch: Any, batch_idx: int, opt_idx: int) -> ClosureResult:
        """Performs the actual train step with the tied hooks.

        Args:
            split_batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
            opt_idx: the index of the current optimizer

        Returns:
            A ``ClosureResult`` containing the training step output.
        """
        # give the PL module a result for logging
        lightning_module = self.trainer.lightning_module

        step_kwargs = _build_training_step_kwargs(
            lightning_module, self.trainer.optimizers, split_batch, batch_idx, opt_idx, self._hiddens
        )

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(
            training_step_output, self.trainer.accumulate_grad_batches
        )

        if self.trainer._terminate_on_nan:
            check_finite_loss(result.closure_loss)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        return result