# Source code for pytorch_lightning.plugins.precision.native_amp
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import contextmanager
from typing import Any, Callable, Dict, Generator, Optional, Union
import torch
from torch import Tensor
from torch.optim import LBFGS, Optimizer
import pytorch_lightning as pl
from lightning_lite.accelerators.cuda import _patch_cuda_is_available
from lightning_lite.utilities.types import Optimizable
from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10, AMPType, GradClipAlgorithmType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _TORCH_GREATER_EQUAL_1_10:
    from torch import autocast as new_autocast
else:
    from torch.cuda.amp import autocast as old_autocast
class NativeMixedPrecisionPlugin(PrecisionPlugin):
    """Plugin for Native Mixed Precision (AMP) training with ``torch.autocast``.

    Args:
        precision: Whether to use ``torch.float16`` (``16``) or ``torch.bfloat16`` (``'bf16'``).
        device: The device for ``torch.autocast``.
        scaler: An optional :class:`torch.cuda.amp.GradScaler` to use.

    Raises:
        MisconfigurationException:
            If ``precision='bf16'`` is requested while the installed torch is older than 1.10, or if a ``scaler``
            is passed together with ``precision='bf16'``.
    """

    backend = AMPType.NATIVE

    def __init__(
        self, precision: Union[str, int], device: str, scaler: Optional[torch.cuda.amp.GradScaler] = None
    ) -> None:
        super().__init__()
        if precision == "bf16" and not _TORCH_GREATER_EQUAL_1_10:
            raise MisconfigurationException(
                "To use bfloat16 with native amp you must install torch greater or equal to 1.10."
            )
        # a default scaler is only created for 16-bit precision; an explicitly passed scaler is kept as-is
        if scaler is None and precision == 16:
            with _patch_cuda_is_available():
                # if possible, we defer CUDA initialization to support strategies that will attempt forks
                scaler = torch.cuda.amp.GradScaler()
        if scaler is not None and precision == "bf16":
            raise MisconfigurationException(f"`precision='bf16'` does not use a scaler, found {scaler}.")
        self.precision = precision
        self.device = device
        self.scaler = scaler

    def pre_backward(self, tensor: Tensor, module: "pl.LightningModule") -> Tensor:  # type: ignore[override]
        """Scale ``tensor`` with the gradient scaler, if one is set, then delegate to the parent hook.

        Args:
            tensor: The tensor that is about to be back-propagated.
            module: The module being trained; forwarded unchanged to the parent implementation.

        Returns:
            Whatever the parent ``pre_backward`` returns for the (possibly scaled) tensor.
        """
        if self.scaler is not None:
            tensor = self.scaler.scale(tensor)
        return super().pre_backward(tensor, module)

    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
        model: "pl.LightningModule",
        optimizer_idx: int,
        closure: Callable[[], Any],
        **kwargs: Any,
    ) -> Any:
        """Run the closure and step the optimizer through the gradient scaler.

        Without a scaler the call is forwarded unchanged to the parent implementation.

        Args:
            optimizer: The optimizer to step.
            model: The module being trained; its ``automatic_optimization`` flag is consulted.
            optimizer_idx: Index of the optimizer; used in the error message and passed on to the hooks.
            closure: Callable that is executed before the optimizer is stepped.
            **kwargs: Extra keyword arguments forwarded to the step call.

        Returns:
            The parent's result when no scaler is set. Otherwise the output of ``self.scaler.step`` when the step
            is performed, or the closure result when the step is skipped.

        Raises:
            MisconfigurationException: If a scaler is set and ``optimizer`` is an :class:`torch.optim.LBFGS`.
        """
        if self.scaler is None:
            # skip scaler logic, as bfloat16 does not require scaler
            return super().optimizer_step(
                optimizer, model=model, optimizer_idx=optimizer_idx, closure=closure, **kwargs
            )
        if isinstance(optimizer, LBFGS):
            raise MisconfigurationException(
                f"Native AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
            )
        closure_result = closure()

        if not _optimizer_handles_unscaling(optimizer):
            # Unscaling needs to be performed here in case we are going to apply gradient clipping.
            # Optimizers that perform unscaling in their `.step()` method are not supported (e.g., fused Adam).
            # Note: `unscale` happens after the closure is executed, but before the `on_before_optimizer_step` hook.
            self.scaler.unscale_(optimizer)

        self._after_closure(model, optimizer, optimizer_idx)
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if not model.automatic_optimization or not skipped_backward:
            # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
            step_output = self.scaler.step(optimizer, **kwargs)
            self.scaler.update()
            return step_output
        return closure_result

    def clip_gradients(
        self,
        optimizer: Optimizer,
        clip_val: Union[int, float] = 0.0,
        gradient_clip_algorithm: GradClipAlgorithmType = GradClipAlgorithmType.NORM,
    ) -> None:
        """Clip gradients via the parent implementation, rejecting optimizers that unscale internally.

        Args:
            optimizer: The optimizer whose gradients should be clipped.
            clip_val: The clipping threshold; the compatibility check only applies when it is greater than zero.
            gradient_clip_algorithm: The clipping algorithm, forwarded to the parent implementation.

        Raises:
            RuntimeError: If ``clip_val > 0`` and the optimizer is flagged as handling gradient unscaling itself.
        """
        if clip_val > 0 and _optimizer_handles_unscaling(optimizer):
            raise RuntimeError(
                f"The current optimizer, {type(optimizer).__qualname__}, does not allow for gradient clipping"
                " because it performs unscaling of gradients internally. HINT: Are you using a 'fused' optimizer?"
            )
        super().clip_gradients(optimizer=optimizer, clip_val=clip_val, gradient_clip_algorithm=gradient_clip_algorithm)

    def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
        """Build the autocast context manager matching the installed torch version.

        Returns:
            ``torch.autocast`` configured with ``self.device`` and an explicit dtype on torch >= 1.10, otherwise
            ``torch.cuda.amp.autocast()``.
        """
        if _TORCH_GREATER_EQUAL_1_10:
            # the dtype could be automatically inferred but we need to manually set it due to a bug upstream
            # https://github.com/pytorch/pytorch/issues/67233
            return new_autocast(self.device, dtype=torch.bfloat16 if self.precision == "bf16" else torch.half)
        return old_autocast()

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        """Enable autocast context."""
        with self.autocast_context_manager():
            yield

    def state_dict(self) -> Dict[str, Any]:
        """Return the scaler's state dict, or an empty dict when no scaler is set."""
        if self.scaler is not None:
            return self.scaler.state_dict()
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Load ``state_dict`` into the scaler; does nothing when no scaler is set."""
        if self.scaler is not None:
            self.scaler.load_state_dict(state_dict)
def _optimizer_handles_unscaling(optimizer: Any) -> bool:
    """Report whether an optimizer unscales gradients inside its own ``step`` method.

    The answer is read from the ``_step_supports_amp_scaling`` attribute on the optimizer, with a missing
    attribute treated as ``False``. Because this is a PyTorch-internal attribute, the result is only reliable
    for built-in PyTorch optimizers.
    """
    amp_flag_name = "_step_supports_amp_scaling"
    unscales_internally = getattr(optimizer, amp_flag_name, False)
    return unscales_internally