Source code for lightning_fabric.plugins.precision.precision
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import contextlib
from typing import Any, Dict, Generator, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from typing_extensions import Literal

from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
from lightning_fabric.utilities.types import _PARAMETERS, Optimizable

_PRECISION_INPUT_INT = Literal[64, 32, 16]
_PRECISION_INPUT_STR = Literal["64", "32", "16", "bf16"]
_PRECISION_INPUT = Union[_PRECISION_INPUT_INT, _PRECISION_INPUT_STR]
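# Illustrative note (an assumption, not part of the original module): the
# Literal aliases above constrain which precision values a caller may pass.
# A hypothetical helper that normalizes the integer forms to their string
# forms could look like this -- ``_normalize_precision`` is invented here
# purely for demonstration.
def _normalize_precision(precision: _PRECISION_INPUT) -> _PRECISION_INPUT_STR:
    # 64/32/16 map onto "64"/"32"/"16"; "bf16" has no integer form and
    # passes through unchanged.
    return str(precision)  # type: ignore[return-value]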
class Precision:
    """Base class for all plugins handling the precision-specific parts of the training.

    The class attribute precision must be overwritten in child classes. The default value reflects fp32 training.
    """

    precision: _PRECISION_INPUT_STR = "32"
    def convert_module(self, module: Module) -> Module:
        """Convert the module parameters to the precision type this plugin handles.

        This is optional and depends on the precision limitations during optimization.
        """
        return module
    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """A context manager for managing model forward/training_step/evaluation_step/predict_step."""
        yield
    def convert_input(self, data: Tensor) -> Tensor:
        """Convert model inputs (forward) to the floating point precision type of this plugin.

        This is a no-op for tensors that are not of floating-point type or already have the desired type.
        """
        return _convert_fp_tensor(data, torch.float32)
    def pre_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        """Runs before the precision plugin executes backward.

        Args:
            tensor: The tensor that will be used for backpropagation
            module: The module that was involved in producing the tensor and whose parameters need the gradients
        """
    def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        """Performs the actual backpropagation.

        Args:
            tensor: The tensor that will be used for backpropagation
            model: The module that was involved in producing the tensor and whose parameters need the gradients
        """
        tensor.backward(*args, **kwargs)
    def post_backward(self, tensor: Tensor, module: Optional[Module]) -> Any:
        """Runs after the precision plugin executes backward.

        Args:
            tensor: The tensor that will be used for backpropagation
            module: The module that was involved in producing the tensor and whose parameters need the gradients
        """
    def optimizer_step(self, optimizer: Optimizable, **kwargs: Any) -> Any:
        """Hook to run the optimizer step."""
        return optimizer.step(**kwargs)
    def main_params(self, optimizer: Optimizer) -> _PARAMETERS:
        """The main params of the model.

        Returns the plain model params here. May differ in other precision plugins.
        """
        for group in optimizer.param_groups:
            yield from group["params"]
    def state_dict(self) -> Dict[str, Any]:
        """Called when saving a checkpoint; implement to generate the precision plugin state_dict.

        Returns:
            A dictionary containing precision plugin state.
        """
        return {}
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint; implement to reload precision plugin state given the precision plugin
        state_dict.

        Args:
            state_dict: the precision plugin state returned by ``state_dict``.
        """
    def teardown(self) -> None:
        """This method is called to tear down the training process.

        It is the right place to release memory and free other resources.
        """
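
# ---------------------------------------------------------------------------
# A minimal sketch (an assumption, not part of this module) of how a subclass
# could override the hooks above for half precision. ``HalfPrecisionSketch``
# and its behavior are illustrative only; Lightning's own half-precision
# plugins are implemented elsewhere and may differ.
class HalfPrecisionSketch(Precision):
    precision: _PRECISION_INPUT_STR = "16"

    def convert_module(self, module: Module) -> Module:
        # Cast the module's floating-point parameters and buffers to float16.
        return module.half()

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # Run forward/training_step/evaluation_step/predict_step under
        # autocast so supported ops execute in float16 automatically.
        with torch.autocast(device_type="cuda", dtype=torch.float16):
            yield

    def convert_input(self, data: Tensor) -> Tensor:
        # Move floating-point inputs to float16; other tensors pass through.
        return _convert_fp_tensor(data, torch.float16)


# Illustrative call order, mirroring how a Fabric-style training loop would
# drive the hooks (all variable names are placeholders):
#
#     plugin = HalfPrecisionSketch()
#     model = plugin.convert_module(model)
#     with plugin.forward_context():
#         loss = model(plugin.convert_input(batch)).sum()
#     plugin.backward(loss, model)
#     plugin.optimizer_step(optimizer)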