Source code for pytorch_lightning.loops.optimization.manual_loop
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from torch import Tensor

from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


@dataclass
class ManualResult(OutputResult):
    """A container to hold the result returned by the ``ManualOptimization`` loop.

    It is created from the output of :meth:`~pytorch_lightning.core.module.LightningModule.training_step`.

    Attributes:
        extra: Anything returned by the ``training_step``.
    """

    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
        """Build a ``ManualResult`` from the raw output of ``training_step``.

        Args:
            training_step_output: ``None`` (no extras), a ``Tensor`` (stored under the ``"loss"`` key), or a dict
                of extras. A ``"hiddens"`` entry in the dict is excluded from the extras.

        Returns:
            A ``ManualResult`` whose ``extra`` holds the outputs, with a tensor ``"loss"`` detached from the graph.

        Raises:
            MisconfigurationException: If ``training_step_output`` is not ``None``, a ``Tensor`` or a dict.
        """
        extra: Dict[str, Any] = {}
        if isinstance(training_step_output, dict):
            extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
        elif isinstance(training_step_output, Tensor):
            extra = {"loss": training_step_output}
        elif training_step_output is not None:
            raise MisconfigurationException(
                "In manual optimization, `training_step` must either return a Tensor, "
                "a dict with extras to pass to `training_epoch_end` or have no return."
            )

        loss = extra.get("loss")
        if isinstance(loss, Tensor):
            # we detach manually as it's expected that it will have a `grad_fn`.
            # Non-tensor values (e.g. a plain float) have no graph to detach from and are passed through as-is
            # instead of failing with an `AttributeError`.
            extra["loss"] = loss.detach()

        return cls(extra=extra)

    def asdict(self) -> Dict[str, Any]:
        """Return the stored extras (the same dict object held by this result, not a copy)."""
        return self.extra


_OUTPUTS_TYPE = Dict[str, Any]
class ManualOptimization(Loop[_OUTPUTS_TYPE]):
    """A special loop implementing what is known in Lightning as Manual Optimization where the optimization happens
    entirely in the :meth:`~pytorch_lightning.core.module.LightningModule.training_step` and therefore the user is
    responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~pytorch_lightning.core.module.LightningModule.training_step`) and passing through the output(s).
    """

    output_result_cls = ManualResult

    def __init__(self) -> None:
        super().__init__()
        # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
        # `OptimizationProgress`
        self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)

        self._done: bool = False
        # hidden state carried between truncated-BPTT splits; fed back into the next `training_step` call
        self._hiddens: Optional[Any] = None
        self._output: _OUTPUTS_TYPE = {}

    @property
    def done(self) -> bool:
        """Whether the single ``advance`` of this loop has already run."""
        return self._done

    def reset(self) -> None:
        """Mark the loop as not done so that the next ``run`` performs its single ``advance`` again.

        Without this, ``_done`` would stay ``True`` after the first run and every subsequent run would skip the
        training step.
        """
        self._done = False

    def on_run_start(self, *_: Any, **__: Any) -> None:
        """Install this loop's progress/profiling hooks around every ``LightningOptimizer.step``."""
        # inject logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers.values():
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step

    def advance(self, kwargs: OrderedDict) -> None:
        """Performs the training step for manual optimization.

        Args:
            kwargs: The kwargs passed down to the hooks.
        """
        kwargs = self._build_kwargs(kwargs, self._hiddens)

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook("training_step", *kwargs.values())
        del kwargs  # release the batch from memory
        self.trainer.strategy.post_training_step()

        # a non-None value from the LightningModule's `training_step_end` takes precedence over the strategy's
        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output
        self._hiddens = _extract_hiddens(training_step_output, self.trainer.lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(training_step_output)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            # the user might need them on the correct device for an operation in `training_epoch_end`
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        self._done = True
        self._output = result.asdict()

    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
        output, self._output = self._output, {}  # free memory
        # reset logic around the optimizer step
        for lightning_optimizer in self.trainer.strategy._lightning_optimizers.values():
            lightning_optimizer._on_before_step = do_nothing_closure
            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:
        """Mark an optimizer step as ready and start profiling it."""
        self.optim_step_progress.increment_ready()
        self.trainer.profiler.start("optimizer_step")

    def _on_after_step(self) -> None:
        """Stop profiling the optimizer step and mark it as completed."""
        self.trainer.profiler.stop("optimizer_step")
        self.optim_step_progress.increment_completed()

    def _build_kwargs(self, kwargs: OrderedDict, hiddens: Optional[Any]) -> OrderedDict:
        """Helper method to build the arguments for the current step.

        Args:
            kwargs: The kwargs passed down to the hooks.
            hiddens: the hidden state of the previous RNN iteration.

        Returns:
            The kwargs passed down to the hooks.
        """
        # `None` is passed as the optimizer index: in manual optimization the user drives the optimizers directly
        return _build_training_step_kwargs(kwargs, self.trainer.lightning_module, self.trainer.optimizers, None, hiddens)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.