Source code for pytorch_lightning.plugins.precision.precision_plugin
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from functools import partial
from typing import Any, Callable, Generator, List, Optional, Tuple, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_fabric.plugins import Precision as FabricPrecision
from lightning_fabric.utilities.types import Steppable
from pytorch_lightning.core.hooks import CheckpointHooks
from pytorch_lightning.utilities import grad_norm, GradClipAlgorithmType
class PrecisionPlugin(FabricPrecision, CheckpointHooks):
    """Base class for all plugins handling the precision-specific parts of the training.

    The class attribute ``precision`` must be overwritten in child classes. The default value reflects fp32 training.
    """
    def connect(
        self, model: Module, optimizers: List[Optimizer], lr_schedulers: List[Any]
    ) -> Tuple[Module, List[Optimizer], List[Any]]:
        """Connects this plugin to the accelerator and the training process."""
        return model, optimizers, lr_schedulers
    def backward(  # type: ignore[override]
        self,
        tensor: Tensor,
        model: "pl.LightningModule",
        optimizer: Optional[Steppable],
        optimizer_idx: Optional[int],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        r"""Performs the actual backpropagation.

        Args:
            tensor: the loss value obtained from the closure
            model: the model to be optimized
            optimizer: current optimizer being used. ``None`` if using manual optimization
            optimizer_idx: the index of the current optimizer. ``None`` if using manual optimization
            \*args: Positional arguments intended for the actual function that performs the backward, like
                :meth:`~torch.Tensor.backward`.
            \**kwargs: Keyword arguments for the same purpose as ``*args``.
        """
        model.backward(tensor, optimizer, optimizer_idx, *args, **kwargs)
    def post_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:  # type: ignore[override]
        # once backward has been applied, release graph
        closure_loss = tensor.detach()
        module.trainer._call_callback_hooks("on_after_backward")
        module.trainer._call_lightning_module_hook("on_after_backward")
        return closure_loss
    def _after_closure(self, model: "pl.LightningModule", optimizer: Steppable, optimizer_idx: int) -> None:
        """Utility to share some code after the closure has been run."""
        trainer = model.trainer
        trainer._call_callback_hooks("on_before_optimizer_step", optimizer, optimizer_idx)
        trainer._call_lightning_module_hook("on_before_optimizer_step", optimizer, optimizer_idx)
        # TODO: this is done for the entire model but should be changed to per-optimizer
        if optimizer_idx == 0:
            self._track_grad_norm(trainer)
        self._clip_gradients(
            model,
            optimizer,
            optimizer_idx,
            trainer.gradient_clip_val,
            gradient_clip_algorithm=trainer.gradient_clip_algorithm,
        )

    def _wrap_closure(
        self,
        model: "pl.LightningModule",
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
    ) -> Any:
        """This double-closure makes sure the ``closure`` is executed before the ``on_before_optimizer_step``
        hook is called.

        The closure (generally) runs ``backward`` so this allows inspecting gradients in this hook. This
        structure is consistent with the ``PrecisionPlugin`` subclasses that cannot pass
        ``optimizer.step(closure)`` directly.
        """
        closure_result = closure()
        self._after_closure(model, optimizer, optimizer_idx)
        return closure_result
    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Steppable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Hook to run the optimizer step."""
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        return optimizer.step(closure=closure, **kwargs)
    def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
        if trainer.track_grad_norm == -1:
            return

        kwargs = {}
        if len(trainer.loggers) == 1:
            kwargs["group_separator"] = trainer.loggers[0].group_separator
        grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
        if grad_norm_dict:
            prev_fx = trainer.lightning_module._current_fx_name
            trainer.lightning_module._current_fx_name = "on_before_optimizer_step"
            trainer.lightning_module.log_grad_norm(grad_norm_dict)
            trainer.lightning_module._current_fx_name = prev_fx

    def _clip_gradients(
        self,
        model: Union["pl.LightningModule", Module],
        optimizer: Steppable,
        optimizer_idx: int,
        clip_val: Optional[Union[int, float]] = None,
        gradient_clip_algorithm: Optional[GradClipAlgorithmType] = None,
    ) -> None:
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization:
            # the configuration validator disallows clipping on manual
            return

        model.trainer._call_lightning_module_hook(
            "configure_gradient_clipping",
            optimizer,
            optimizer_idx,
            gradient_clip_val=clip_val,
            gradient_clip_algorithm=gradient_clip_algorithm,
        )
    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Clips the gradients."""
        if clip_val <= 0:
            return
        if gradient_clip_algorithm == GradClipAlgorithmType.VALUE:
            self.clip_grad_by_value(optimizer, clip_val)
        elif gradient_clip_algorithm == GradClipAlgorithmType.NORM:
            self.clip_grad_by_norm(optimizer, clip_val)
    def clip_grad_by_value(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by value."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
    def clip_grad_by_norm(self, optimizer: Optimizer, clip_val: Union[int, float]) -> None:
        """Clip gradients by norm."""
        parameters = self.main_params(optimizer)
        torch.nn.utils.clip_grad_norm_(parameters, clip_val)
    def dispatch(self, trainer: "pl.Trainer") -> None:
        """Hook to do something when ``Strategy.dispatch()`` gets called."""
    @contextlib.contextmanager
    def train_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the training step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def val_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the validation step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def test_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the test step."""
        with self.forward_context():
            yield
    @contextlib.contextmanager
    def predict_step_context(self) -> Generator[None, None, None]:
        """A contextmanager for the predict step."""
        with self.forward_context():
            yield
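
Because the train, validation, test, and predict step contexts above all funnel through ``forward_context()``, a custom plugin usually only needs to set the ``precision`` attribute and override that one context manager. A minimal sketch under that assumption; the class name, precision string, and dtype below are illustrative, not part of pytorch_lightning:

# Hypothetical subclass sketch; name, precision string, and dtype are illustrative only.
import contextlib
from typing import Generator

import torch


class BF16AutocastSketch(PrecisionPlugin):
    """Runs every forward pass under torch.autocast with bfloat16."""

    precision = "bf16"  # overrides the base class attribute; accepted values depend on the Lightning version

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # All step contexts wrap their body in forward_context(), so overriding it
        # here is enough to affect the train/val/test/predict steps alike.
        with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
            yield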
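
``optimizer_step`` never calls the user closure directly: it binds it into ``_wrap_closure`` via ``functools.partial``, so the closure (which generally runs ``backward``) finishes before ``on_before_optimizer_step`` fires, and only the wrapped callable is handed to ``optimizer.step``. A toy sketch of that ordering, with plain prints standing in for the Lightning hooks:

# Toy illustration of the double-closure ordering; prints stand in for the real hooks.
from functools import partial
from typing import Any, Callable


def wrap_closure_sketch(closure: Callable[[], Any]) -> Any:
    result = closure()  # forward + backward run first
    print("on_before_optimizer_step: gradients are now available for inspection")
    return result


def training_closure() -> float:
    print("closure: forward + backward")
    return 0.0


wrapped = partial(wrap_closure_sketch, training_closure)
wrapped()  # optimizer.step(closure=wrapped) would invoke the wrapped callable the same way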
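
``clip_gradients`` is a small dispatcher: it is a no-op for ``clip_val <= 0`` and otherwise routes to ``clip_grad_by_value`` or ``clip_grad_by_norm``, both of which operate on ``self.main_params(optimizer)``. The same dispatch written out as a standalone sketch against ``torch.nn.utils``; the tiny model and optimizer are hypothetical stand-ins:

# Standalone sketch of the clipping dispatch; the model and optimizer are toy stand-ins.
import torch
from torch import nn

model = nn.Linear(4, 2)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
model(torch.randn(8, 4)).sum().backward()  # populate .grad before clipping

clip_val = 0.5
algorithm = "norm"  # mirrors GradClipAlgorithmType.NORM; "value" mirrors GradClipAlgorithmType.VALUE
params = [p for group in opt.param_groups for p in group["params"]]
if clip_val > 0:
    if algorithm == "value":
        torch.nn.utils.clip_grad_value_(params, clip_value=clip_val)
    elif algorithm == "norm":
        torch.nn.utils.clip_grad_norm_(params, clip_val)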