Source code for lightning.pytorch.plugins.precision.amp
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Literal, Optional, Union

import torch
from torch import Tensor
from torch.optim import LBFGS, Optimizer

import lightning.pytorch as pl
from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.amp import _optimizer_handles_unscaling
from lightning.fabric.utilities.types import Optimizable
from lightning.pytorch.plugins.precision.precision import Precision
from lightning.pytorch.utilities import GradClipAlgorithmType
from lightning.pytorch.utilities.exceptions import MisconfigurationException

class MixedPrecision(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.

    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"],
        device: str,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ) -> None:
        if precision not in ("16-mixed", "bf16-mixed"):
            raise ValueError(
                f"Passed `{type(self).__name__}(precision={precision!r})`."
                f" Precision must be '16-mixed' or 'bf16-mixed'."
            )
        self.precision = precision
        if scaler is None and self.precision == "16-mixed":
            with _patch_cuda_is_available():
                # if possible, we defer CUDA initialization to support strategies that will attempt forks
                scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and self.precision == "bf16-mixed":
            raise MisconfigurationException(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
        self.device = device
        self.scaler = scaler

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(optimizer, model=model, closure=closure, **kwargs)
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException("AMP and the LBFGS optimizer are not compatible.")
        closure_result = closure()

        # If backward was skipped in automatic optimization (return None), unscaling is not needed
        skip_unscaling = closure_result is None and model.automatic_optimization

        if not _optimizer_handles_unscaling(optimizer) and not skip_unscaling:
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer)

        # in manual optimization, the closure does not return a value
        if not skip_unscaling:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)
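
For context, a minimal usage sketch of this plugin (not part of the module above): it shows the plugin constructed explicitly and passed to the Trainer. ``MyModel`` is a placeholder LightningModule and the custom ``GradScaler`` argument is illustrative only.

import torch
import lightning.pytorch as pl
from lightning.pytorch.plugins.precision.amp import MixedPrecision

# Build the plugin explicitly; with precision="16-mixed" a GradScaler is created
# automatically if none is passed, so the explicit scaler below is optional.
amp_plugin = MixedPrecision(
    precision="16-mixed",
    device="cuda",
    scaler=torch.cuda.amp.GradScaler(init_scale=2.0**16),
)

# Roughly equivalent to Trainer(precision="16-mixed") on a CUDA accelerator,
# but allows customizing the scaler or the device passed to torch.autocast.
trainer = pl.Trainer(accelerator="cuda", devices=1, plugins=[amp_plugin])
trainer.fit(MyModel())  # MyModel is a hypothetical LightningModule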