Source code for pytorch_lightning.plugins.precision.native_amp
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union

import torch
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer

import pytorch_lightning as pl
from pytorch_lightning.plugins.precision.mixed import MixedPrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

if _TORCH_GREATER_EQUAL_1_10:
    from torch import autocast as new_autocast
else:
    from torch.cuda.amp import autocast as old_autocast
class NativeMixedPrecisionPlugin(MixedPrecisionPlugin):
    """Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.
    """

    backend = AMPType.NATIVE

    def __init__(
        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
        super().__init__()
        if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
            raise MisconfigurationException(
                "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
            )
        if scaler is None and precision == 16:
            scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and precision == "bf16":
            raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
        self.precision = precision
        self.device = device
        self.scaler = scaler
    def optimizer_step(
        self,
        model: Optional[Union["pl.LightningModule", Module]],
        optimizer: Optimizer,
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()
        # `unscale` after the closure is executed but before the `on_before_optimizer_step` hook.
        self.scaler.unscale_(optimizer)
        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result
    def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
        if _TORCH_GREATER_EQUAL_1_10:
            # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
            # https://github.com/pytorch/pytorch/issues/67233
            return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
        return old_autocast()
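A usage sketch, not part of this module: in this release of Lightning, native AMP is usually enabled through the Trainer's ``precision`` flag rather than by instantiating ``NativeMixedPrecisionPlugin`` directly, though the plugin can also be constructed explicitly (for example to supply a custom ``GradScaler``). The ``gpus`` and ``plugins`` Trainer arguments below are assumptions about this release's API.

    import torch
    import pytorch_lightning as pl
    from pytorch_lightning.plugins import NativeMixedPrecisionPlugin

    # Common path: the Trainer builds the plugin from the `precision` flag.
    trainer = pl.Trainer(precision=16, gpus=1)        # float16 autocast + GradScaler
    # trainer = pl.Trainer(precision="bf16", gpus=1)  # bfloat16 autocast, no scaler (torch >= 1.10)

    # Explicit construction with a custom scaler; passing it via `plugins=[...]`
    # is an assumption about this release's Trainer API.
    scaler = torch.cuda.amp.GradScaler(init_scale=2.0 ** 14)
    plugin = NativeMixedPrecisionPlugin(precision=16, device="cuda", scaler=scaler)
    trainer = pl.Trainer(gpus=1, plugins=[plugin])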
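For reference, a minimal sketch of the raw ``torch.cuda.amp`` pattern that ``optimizer_step`` wraps: the closure above plays the role of the forward/backward below, ``unscale_`` runs before the step so that hooks (e.g. gradient clipping) see unscaled gradients, and ``scaler.step`` internally skips ``optimizer.step`` when non-finite gradients are found. The model, data, and hyperparameters are placeholders.

    import torch
    import torch.nn.functional as F

    model = torch.nn.Linear(32, 2).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler()
    loader = [(torch.randn(8, 32), torch.randint(0, 2, (8,))) for _ in range(4)]  # toy data

    for batch, target in loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            loss = F.cross_entropy(model(batch.cuda()), target.cuda())
        scaler.scale(loss).backward()  # backward on the scaled loss
        scaler.unscale_(optimizer)     # unscale so hooks see true gradient values
        scaler.step(optimizer)         # skipped internally if grads are inf/NaN
        scaler.update()                # adjust the scale factor for the next step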
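The context manager returned by ``autocast_context_manager`` maps onto the two public autocast entry points. A small sketch of the difference, with the dtype passed explicitly as in the workaround above; the ``hasattr`` check stands in for ``_TORCH_GREATER_EQUAL_1_10`` and the ``device``/``precision`` values mirror the plugin attributes.

    import torch

    device, precision = "cuda", 16  # or precision = "bf16" on torch >= 1.10

    if hasattr(torch, "autocast"):  # torch >= 1.10: device-agnostic API, dtype set explicitly
        ctx = torch.autocast(device, dtype=torch.bfloat16 if precision == "bf16" else torch.half)
    else:  # older torch: CUDA-only autocast, always float16
        ctx = torch.cuda.amp.autocast()

    with ctx:
        # ops inside the context run in the reduced-precision dtype where it is safe to do so
        out = torch.nn.Linear(16, 4).to(device)(torch.randn(2, 16, device=device))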