Source code for lightning_fabric.plugins.precision.native_amp
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, cast, Dict, Generator, Optional

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS
from typing_extensions import Literal

from lightning_fabric.accelerators.cuda import _patch_cuda_is_available
from lightning_fabric.plugins.precision.precision import Precision
from lightning_fabric.plugins.precision.utils import _convert_fp_tensor
from lightning_fabric.utilities.types import Optimizable
class MixedPrecision(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.

    """

    def __init__(
        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
        self.precision = cast(Literal["16", "bf16"], str(precision))
        if scaler is None and self.precision == "16":
            with _patch_cuda_is_available():
                # if possible, we defer CUDA initialization to support strategies that will attempt forks
                scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and self.precision == "bf16":
            raise ValueError(f"`precision='bf16'` does not use a scaler, found {scaler}.")
        self.device = device
        self.scaler = scaler
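A minimal sketch of the constructor's behavior (assuming ``lightning_fabric`` is importable and the values below follow the signature above): float16 precision gets a ``GradScaler`` created automatically when none is supplied, while bfloat16 runs without one.

import torch
from lightning_fabric.plugins.precision.native_amp import MixedPrecision

# float16: a GradScaler is created automatically when none is supplied
fp16_plugin = MixedPrecision(precision="16", device="cuda")
assert isinstance(fp16_plugin.scaler, torch.cuda.amp.GradScaler)

# bfloat16: no gradient scaling is needed, so no scaler is created
bf16_plugin = MixedPrecision(precision="bf16", device="cuda")
assert bf16_plugin.scaler is None

# passing a scaler together with precision="bf16" raises a ValueError (see __init__ above)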
    def optimizer_step(
        self,
        optimizer: Optimizable,
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(optimizer, **kwargs)
        if isinstance(optimizer, LBFGS):
            raise TypeError("Native AMP and the LBFGS optimizer are not compatible.")
        # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
        step_output = self.scaler.step(optimizer, **kwargs)
        self.scaler.update()
        return step_output
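A minimal sketch of how ``optimizer_step`` fits into a float16 training iteration (the model, data, and optimizer below are hypothetical placeholders, and a CUDA device is assumed so the scaler is active): the loss is scaled before ``backward``, and the plugin then steps the optimizer through the scaler, which skips the step when non-finite gradients are found and updates the scale factor afterwards.

import torch
from lightning_fabric.plugins.precision.native_amp import MixedPrecision

plugin = MixedPrecision(precision="16", device="cuda")
model = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

with torch.autocast("cuda", dtype=torch.float16):  # forward pass under autocast
    loss = model(torch.randn(8, 4, device="cuda")).sum()

plugin.scaler.scale(loss).backward()  # backward on the scaled loss
plugin.optimizer_step(optimizer)      # unscale + step (skipped on non-finite grads) + scaler.update()
optimizer.zero_grad()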
    def _autocast_context_manager(self) -> torch.autocast:
        # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
        # https://github.com/pytorch/pytorch/issues/67233
        return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
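For reference, the context manager returned by this helper is a plain ``torch.autocast`` with the dtype selected from the configured precision; a minimal standalone illustration (the device and precision values below are examples):

import torch

precision = "bf16"  # or "16"
device = "cpu"      # or "cuda"

# equivalent to what _autocast_context_manager returns for these settings
with torch.autocast(device, dtype=torch.bfloat16 if precision == "bf16" else torch.half):
    x = torch.randn(4, 4)
    y = torch.mm(x, x)  # autocast-eligible ops run in the reduced-precision dtype here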