Source code for pytorch_lightning.plugins.precision.precision_plugin
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partial
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType
from pytorch_lightning.utilities.types import _PARAMETERS
class PrecisionPlugin(CheckpointHooks):
    """Base class for all plugins handling the precision-specific parts of the training.

    The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
    """

    precision: Union[str, int] = 32
    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
        """The main params of the model.

        Returns the plain model params here. May be different in other precision plugins.
        """
        for group in optimizer.param_groups:
            yield from group["params"]
    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        return model, optimizers, lr_schedulers
    def pre_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
        """Run before precision plugin executes backward.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
        """
        model.trainer._call_callback_hooks("on_before_backward", closure_loss)
        model.trainer._call_lightning_module_hook("on_before_backward", closure_loss)
        return closure_loss
    def backward(
        self,
        model: "pl.LightningModule",
        closure_loss: Tensor,
        optimizer: Optional[Optimizer],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Performs the actual backpropagation.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
            optimizer: current optimizer being used. ``None`` if using manual optimization
        """
        # do backward pass
        if model is not None and isinstance(model, pl.LightningModule):
            model.backward(closure_loss, optimizer, optimizer_idx, *args, **kwargs)
        else:
            self._run_backward(closure_loss, *args, **kwargs)
    def post_backward(self, model: "pl.LightningModule", closure_loss: Tensor) -> Tensor:
        """Run after precision plugin executes backward.

        Args:
            model: the model to be optimized
            closure_loss: the loss value obtained from the closure
        """
        # once backward has been applied, release graph
        closure_loss = closure_loss.detach()
        model.trainer._call_callback_hooks("on_after_backward")
        model.trainer._call_lightning_module_hook("on_after_backward")
        return closure_loss
    def _run_backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        """Lightning-independent backward logic.

        Currently only used by Lightning Lite. Subject to further refactors.
        """
        tensor.backward(*args, **kwargs)

    def _after_closure(
        self, model: Optional[Union["pl.LightningModule", Module]], optimizer: Optimizer, optimizer_idx: int
    ) -> None:
        """Utility to share some code after the closure has been run."""
        if not isinstance(model, pl.LightningModule):
            # none of this applies to Lite
            return
        trainer = model.trainer
        trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
        trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
        # TODO: this is done for the entire model but should be changed to per-optimizer
        if optimizer_idx == 0:
            self._track_grad_norm(trainer)
        self._clip_gradients(
            model,
            optimizer,
            optimizer_idx,
            trainer.gradient_clip_val,
            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
        )

    def _wrap_closure(
        self,
        model: "pl.LightningModule",
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
    ) -> Any:
        """This double-closure makes sure the ``closure`` is executed before the ``on_before_optimizer_step`` hook
        is called.

        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This structure is
        consistent with the ``PrecisionPlugin`` subclasses that cannot pass ``optimizer.step(closure)`` directly.
        """
        closure_result = closure()
        self._after_closure(model, optimizer, optimizer_idx)
        return closure_result
    def optimizer_step(
        self,
        model: Optional[Union["pl.LightningModule", Module]],
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Hook to run the optimizer step."""
        if isinstance(model, pl.LightningModule):
            closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        return optimizer.step(closure=closure, **kwargs)
    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
        if trainer.track_grad_norm == -1:
            return

        kwargs = {}
        if len(trainer.loggers) == 1:
            kwargs["group_separator"] = trainer.loggers[0].group_separator
        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
        if grad_norm_dict:
            prev_fx = trainer.lightning_module._current_fx_name
            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
            trainer.lightning_module.log_grad_norm(grad_norm_dict)
            trainer.lightning_module._current_fx_name = prev_fx

    def _clip_gradients(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Optimizer,
        optimizer_idx: int,
        clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
    ) -> None:
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
            # the configuration validator disallows clipping on manual
            return
        model.trainer._call_lightning_module_hook(
            "configure_gradient_clipping",
            optimizer,
            optimizer_idx,
            gradient_clip_val=clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Clips the gradients."""
        if clip_val <= 0:
            return
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            self.clip_grad_by_value(optimizer, clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            self.clip_grad_by_norm(optimizer, clip_val)
    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)
    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something when ``Strategy.dispatch()`` gets called."""
    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A contextmanager for managing model forward/training_step/evaluation_step/predict_step."""
        yield
    @contextlib.contextmanager
    def train_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the training step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def val_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the validation step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def test_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the test step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def predict_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the predict step."""
        with self.forward_context():
            yield
    def teardown(self) -> None:
        """This method is called to tear down the training process.

        It is the right place to release memory and free other resources.
        """
    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint, implement to generate precision plugin state_dict.

        Returns:
            A dictionary containing precision plugin state.
        """
        return {}
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint, implement to reload precision plugin state given precision plugin
        state_dict.

        Args:
            state_dict: the precision plugin state returned by ``state_dict``.
        """
        pass
    def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """``PrecisionPlugin.on_save_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.

        Use ``state_dict`` instead.
        """
    def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
        """``PrecisionPlugin.on_load_checkpoint`` was deprecated in v1.6 and will be removed in v1.8.

        Use ``load_state_dict`` instead.
        """
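For context (not part of the original file): a custom precision plugin is normally built by subclassing ``PrecisionPlugin`` and overriding the hooks defined above, such as ``forward_context`` or the ``clip_grad_by_*`` methods. The following is a minimal sketch under stated assumptions; the class name ``BFloat16ContextPlugin`` and the hard-coded CPU autocast device are hypothetical choices made for illustration only.

# A minimal sketch (not from the original source): a hypothetical subclass that
# runs every forward pass under torch.autocast with bfloat16.
import contextlib
from typing import Generator, Union

import torch

from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin


class BFloat16ContextPlugin(PrecisionPlugin):
    # the ``precision`` class attribute must be overwritten in child classes
    precision: Union[str, int] = "bf16"

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # training_step/validation_step/test_step/predict_step all run inside
        # ``forward_context``, so autocasting here covers every forward pass.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            yield

Such a plugin could then be handed to the ``Trainer`` via its ``plugins`` argument (e.g. ``Trainer(plugins=[BFloat16ContextPlugin()])``), although the exact wiring depends on the strategy and Lightning version in use.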