Source code for pytorch_lightning.loops.optimization.manual_loop
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass, field
from typing import Any, Dict, Optional

from torch import Tensor

from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT


@dataclass
class ManualResult(OutputResult):
    """A container to hold the result returned by the ``ManualOptimization`` loop.

    It is created from the output of
    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`.

    Attributes:
        extra: Anything returned by the ``training_step``.
    """

    extra: Dict[str, Any] = field(default_factory=dict)

    @classmethod
    def from_training_step_output(cls, training_step_output: Optional[STEP_OUTPUT]) -> "ManualResult":
        extra = {}
        if isinstance(training_step_output, dict):
            extra = {k: v for k, v in training_step_output.items() if k != "hiddens"}
        elif isinstance(training_step_output, Tensor):
            extra = {"loss": training_step_output}
        elif training_step_output is not None:
            raise MisconfigurationException(
                "In manual optimization, `training_step` must either return a Tensor, "
                "a dict with extras to pass to `training_epoch_end` or have no return."
            )

        if "loss" in extra:
            # we detach manually as it's expected that it will have a `grad_fn`
            extra["loss"] = extra["loss"].detach()

        return cls(extra=extra)

    def asdict(self) -> Dict[str, Any]:
        return self.extra


_OUTPUTS_TYPE = Dict[str, Any]
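As a quick illustration of the normalization above (not part of the module; a sketch assuming only the definitions in this file plus ``torch``):

import torch

loss = 2 * torch.tensor(1.0, requires_grad=True)  # a loss carrying a `grad_fn`

# a bare Tensor is wrapped as {"loss": <detached tensor>}
result = ManualResult.from_training_step_output(loss)
assert result.asdict()["loss"].grad_fn is None

# a dict passes through minus the "hiddens" key; "loss" is detached
result = ManualResult.from_training_step_output({"loss": loss, "count": 3, "hiddens": None})
assert sorted(result.asdict()) == ["count", "loss"]

# no return value yields an empty dict; any other type raises `MisconfigurationException`
assert ManualResult.from_training_step_output(None).asdict() == {}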
class ManualOptimization(Loop[_OUTPUTS_TYPE]):
    """A special loop implementing what is known in Lightning as Manual Optimization, where the optimization
    happens entirely in the :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step` and therefore
    the user is responsible for back-propagating gradients and making calls to the optimizers.

    This loop is a trivial case because it performs only a single iteration (calling directly into the module's
    :meth:`~pytorch_lightning.core.lightning.LightningModule.training_step`) and passes through the output(s).
    """

    output_result_cls = ManualResult

    def __init__(self) -> None:
        super().__init__()
        # since manual optimization does not track lr scheduler or optimizer frequencies,
        # we use a simpler progress tracker than `OptimizationProgress`
        self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)

        self._done: bool = False
        self._hiddens: Optional[Any] = None
        self._output: _OUTPUTS_TYPE = {}

    @property
    def done(self) -> bool:
        return self._done
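    # Usage sketch (comment only, not part of this class): a `LightningModule` opts
    # into this loop by disabling automatic optimization and stepping its own
    # optimizers inside `training_step`, e.g.
    #
    #     class MyModel(LightningModule):
    #         def __init__(self):
    #             super().__init__()
    #             self.automatic_optimization = False
    #             ...
    #
    #         def training_step(self, batch, batch_idx):
    #             opt = self.optimizers()
    #             opt.zero_grad()
    #             loss = self.compute_loss(batch)   # hypothetical helper
    #             self.manual_backward(loss)        # replaces `loss.backward()`
    #             opt.step()                        # runs the hooks injected below
    #             return loss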
    def on_run_start(self, *_: Any, **__: Any) -> None:
        # inject logic around the optimizer step
        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
            lightning_optimizer._on_before_step = self._on_before_step
            lightning_optimizer._on_after_step = self._on_after_step
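    # How the injection above takes effect (simplified sketch; the real logic lives
    # in `pytorch_lightning.core.optimizer.LightningOptimizer`): the wrapper brackets
    # the underlying optimizer step with these callbacks, roughly
    #
    #     def step(self, closure=None):
    #         self._on_before_step()    # loop marks the step as "ready"
    #         ...                       # strategy performs the actual step
    #         self._on_after_step()     # loop marks the step as "completed"
    #
    # so `optim_step_progress` counts every `opt.step()` made inside `training_step`.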
    def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
        """Performs the training step for manual optimization.

        Args:
            batch: the current tbptt split of the current batch
            batch_idx: the index of the current batch
        """
        assert self.trainer is not None
        lightning_module = self.trainer.lightning_module

        step_kwargs = _build_training_step_kwargs(
            lightning_module, self.trainer.optimizers, batch, batch_idx, opt_idx=None, hiddens=self._hiddens
        )

        # manually capture logged metrics
        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
        self.trainer.strategy.post_training_step()

        del step_kwargs

        model_output = self.trainer._call_lightning_module_hook("training_step_end", training_step_output)
        strategy_output = self.trainer._call_strategy_hook("training_step_end", training_step_output)
        training_step_output = strategy_output if model_output is None else model_output

        self._hiddens = _extract_hiddens(training_step_output, lightning_module.truncated_bptt_steps)

        result = self.output_result_cls.from_training_step_output(training_step_output)

        if self.trainer.move_metrics_to_cpu:
            # hiddens and the training step output are not moved as they are not considered "metrics"
            # the user might need them on the correct device for an operation in `training_epoch_end`
            assert self.trainer._results is not None
            self.trainer._results.cpu()

        self._done = True
        self._output = result.asdict()
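    # Data-flow sketch for truncated backpropagation through time (comment only;
    # the `rnn` attribute is illustrative): with `truncated_bptt_steps > 0`, a
    # `training_step` such as
    #
    #     def training_step(self, batch, batch_idx, hiddens):
    #         out, new_hiddens = self.rnn(batch, hiddens)
    #         loss = out.sum()
    #         return {"loss": loss, "hiddens": new_hiddens}
    #
    # is split in two above: `_extract_hiddens` detaches and stores "hiddens" for
    # the next split (fed back in via `self._hiddens`), while
    # `from_training_step_output` keeps the remaining keys, leaving
    # `self._output == {"loss": <detached loss>}`.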
    def on_run_end(self) -> _OUTPUTS_TYPE:
        """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
        output, self._output = self._output, {}  # free memory
        # reset logic around the optimizer step
        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
            lightning_optimizer._on_before_step = do_nothing_closure
            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:
        self.optim_step_progress.increment_ready()
        self.trainer.profiler.start("optimizer_step")

    def _on_after_step(self) -> None:
        self.trainer.profiler.stop("optimizer_step")
        self.optim_step_progress.increment_completed()
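For orientation, a hedged sketch of how the parent loop drives these hooks, assuming the generic ``Loop`` contract (``on_run_start``, then ``advance`` until ``done``, then ``on_run_end``); wiring a trainer by hand as below is for illustration only:

loop = ManualOptimization()
loop.trainer = trainer  # done by the parent loop in practice
loop.on_run_start()     # install the optimizer-step callbacks
while not loop.done:    # `done` flips to True after a single `advance`
    loop.advance(batch, batch_idx)
outputs = loop.on_run_end()  # e.g. {"loss": tensor(...)}; callbacks reset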