Source code for lightning.fabric.plugins.precision.amp

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, ContextManager, Dict, Literal, Optional

import torch
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from torch.nn import Module
from torch.optim import LBFGS, Optimizer
from typing_extensions import override

from lightning.fabric.accelerators.cuda import _patch_cuda_is_available
from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import _convert_fp_tensor
from lightning.fabric.utilities.types import Optimizable


class MixedPrecision(Precision):
    """Plugin for Automatic Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``'16-mixed'``) or ``torch.bfloat16`` (``'bf16-mixed'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.

    """

    def __init__(
        self,
        precision: Literal["16-mixed", "bf16-mixed"],
        device: str,
        scaler: Optional[torch.cuda.amp.GradScaler] = None,
    ) -> None:
        if precision not in ("16-mixed", "bf16-mixed"):
            raise ValueError(
                f"Passed `{type(self).__name__}(precision={precision!r})`."
                " Precision must be '16-mixed' or 'bf16-mixed'."
            )
        self.precision = precision
        if scaler is None and self.precision == "16-mixed":
            with _patch_cuda_is_available():
                # if possible, we defer CUDA initialization to support strategies that will attempt forks
                scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and self.precision == "bf16-mixed":
            raise ValueError(f"`precision='bf16-mixed'` does not use a scaler, found {scaler}.")
        self.device = device
        self.scaler = scaler
        self._desired_input_dtype = torch.bfloat16 if self.precision == "bf16-mixed" else torch.float16
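
    # A minimal usage sketch (not part of the original source; `model`, `batch`, and
    # `optimizer` are hypothetical stand-ins for a typical training setup):
    #
    #     plugin = MixedPrecision(precision="16-mixed", device="cuda")
    #     with plugin.forward_context():
    #         loss = model(batch).sum()      # forward pass runs under torch.autocast
    #     plugin.backward(loss, model)       # the loss is scaled before backward
    #     plugin.optimizer_step(optimizer)   # unscales, then skips the step on inf/nan grads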

    @override
    def forward_context(self) -> ContextManager:
        return torch.autocast(self.device, dtype=self._desired_input_dtype)

    @override
    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self._desired_input_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())
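
    # For example (a sketch, not from the original source): with `precision="16-mixed"`,
    # `convert_input(torch.randn(2))` yields a float16 tensor, while `convert_output`
    # casts float16 results back to `torch.get_default_dtype()` (float32 by default).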

    @override
    def backward(self, tensor: Tensor, model: Optional[Module], *args: Any, **kwargs: Any) -> None:
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        super().backward(tensor, model, *args, **kwargs)

    @override
    def optimizer_step(
        self,
        optimizer: Optimizable,
        **kwargs: Any,
    ) -> Any:
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(optimizer, **kwargs)
        if isinstance(optimizer, LBFGS):
            raise TypeError("AMP and the LBFGS optimizer are not compatible.")
        # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
        step_output = self.scaler.step(optimizer, **kwargs)
        self.scaler.update()
        return step_output
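
    # For reference, the calls above mirror the standard PyTorch GradScaler recipe
    # (a sketch, not part of the original source):
    #
    #     scaler.scale(loss).backward()   # done by `backward` above
    #     scaler.step(optimizer)          # skips the step if nonfinite grads were found
    #     scaler.update()                 # adapts the scale factor for the next iteration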

    @override
    def state_dict(self) -> Dict[str, Any]:
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    @override
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)

    @override
    def unscale_gradients(self, optimizer: Optimizer) -> None:
        scaler = self.scaler
        if scaler is not None:
            if _optimizer_handles_unscaling(optimizer):
                raise NotImplementedError("Gradient clipping is not implemented for optimizers handling the unscaling.")
            scaler.unscale_(optimizer)


def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Determines whether a PyTorch optimizer handles unscaling gradients in the step method rather than through the
    :class:`torch.cuda.amp.GradScaler`.

    Since the current implementation of this function checks a PyTorch internal variable on the optimizer, the return
    value will only be reliable for built-in PyTorch optimizers.

    """
    return getattr(optimizer, "_step_supports_amp_scaling", False)
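

# Gradient clipping under AMP must operate on unscaled gradients, which is why
# `unscale_gradients` exists. A minimal sketch (hypothetical `plugin`, `model`, and
# `optimizer` objects; not part of the original source):
#
#     plugin.unscale_gradients(optimizer)  # no-op without a scaler; unscales in place otherwise
#     torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
#     plugin.optimizer_step(optimizer)     # GradScaler will not unscale a second time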