Source code for pytorch_lightning.plugins.precision.native_amp
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, cast, Dict, Generator, Optional, Union

import torch
from torch import Tensor
from torch.optim import LBFGS, Optimizer
from typing_extensions import Literal

import pytorch_lightning as pl
from lightning_fabric.accelerators.cuda import _patch_cuda_is_available
from lightning_fabric.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

class MixedPrecisionPlugin(PrecisionPlugin):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    def __init__(
        self, precision: Literal["16", 16, "bf16"], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
        self.precision = cast(Literal["16", "bf16"], str(precision))
        if scaler is None and self.precision == "16":
            with _patch_cuda_is_available():
                # if possible, we defer CUDA initialization to support strategies that will attempt forks
                scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and self.precision == "bf16":
            raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
        self.device = device
        self.scaler = scaler

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    def autocast_context_manager(self) -> torch.autocast:
        # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
        # https://github.com/pytorch/pytorch/issues/67233
        return torch.autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
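

# --- Illustrative sketch (not part of the original module) -------------------------------------
# The methods above mirror the canonical raw-PyTorch AMP recipe: run the forward pass under
# ``torch.autocast``, scale the loss before ``backward()``, unscale before clipping, and let the
# ``GradScaler`` drive ``optimizer.step()`` and ``update()``. The helper below is a hypothetical,
# self-contained demonstration of that pattern; ``_demo_manual_amp_step`` is not part of this module.
def _demo_manual_amp_step(
    model: torch.nn.Module, optimizer: Optimizer, batch: Tensor, target: Tensor, scaler: torch.cuda.amp.GradScaler
) -> None:
    with torch.autocast("cuda", dtype=torch.half):  # forward and loss computed in mixed precision
        loss = torch.nn.functional.mse_loss(model(batch), target)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.unscale_(optimizer)  # unscale so gradient clipping sees the true gradient magnitudes
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    scaler.step(optimizer)  # the scaler skips the step if non-finite gradients were found
    scaler.update()
    optimizer.zero_grad()

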
class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    backend = "native"

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        rank_zero_deprecation(
            "The `NativeMixedPrecisionPlugin` class has been renamed in v1.9.0 and will be removed in"
            " v2.0.0. Please use `pytorch_lightning.plugins.MixedPrecisionPlugin` instead."
        )
        super().__init__(*args, **kwargs)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through
    the :class:`torch.cuda.amp.GradScaler`.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the
    return value will only be reliable for built-in PyTorch optimizers.
    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
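
As a usage sketch (not part of the module above), the plugin can be constructed explicitly and handed to a ``Trainer`` via its ``plugins`` argument, as the deprecation message suggests. The ``LightningModule`` placeholder ``MyLightningModule`` is hypothetical.

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.plugins import MixedPrecisionPlugin

# float16 AMP on a single GPU; the scaler argument is optional and defaults to a fresh GradScaler
plugin = MixedPrecisionPlugin(precision="16", device="cuda", scaler=torch.cuda.amp.GradScaler())
trainer = Trainer(accelerator="gpu", devices=1, plugins=[plugin])
# trainer.fit(MyLightningModule())  # MyLightningModule is a placeholder for a user-defined module

In most cases the shorthand ``Trainer(precision=16)`` or ``Trainer(precision="bf16")`` selects this plugin automatically; constructing it by hand is only needed to customize the device or the ``GradScaler``.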