Source code for lightning.fabric.plugins.precision.transformer_engine

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import ExitStack
from typing import TYPE_CHECKING, Any, ContextManager, Literal, Mapping, Optional, Union

import torch
from lightning_utilities import apply_to_collection
from lightning_utilities.core.imports import RequirementCache
from torch import Tensor
from typing_extensions import override

from lightning.fabric.plugins.precision.precision import Precision
from lightning.fabric.plugins.precision.utils import (
    _ClassReplacementContextManager,
    _convert_fp_tensor,
    _DtypeContextManager,
)
from lightning.fabric.utilities.rank_zero import rank_zero_info, rank_zero_warn

if TYPE_CHECKING:
    from transformer_engine.common.recipe import DelayedScaling

_TRANSFORMER_ENGINE_AVAILABLE = RequirementCache("transformer_engine>=0.11.0")
log = logging.getLogger(__name__)


class TransformerEnginePrecision(Precision):
    """Plugin for training with fp8 precision via nvidia's
    `Transformer Engine <https://docs.nvidia.com/deeplearning/transformer-engine>`__.

    .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

    Args:
        weights_dtype: The weights dtype to use.
        recipe: Recipe for the DelayedScaling
            `configuration <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html#transformer_engine.common.recipe.DelayedScaling>`__.
            In dict format or the dataclass format.
        replace_layers: Whether to replace ``Linear`` and ``LayerNorm`` layers automatically with their Transformer
            Engine alternatives. Note that they don't subclass the torch equivalents so checks like
            ``isinstance(l, torch.nn.Linear)`` will not pass.
        fallback_compute_dtype: The compute dtype to use for operations that don't support fp8 autocast. Defaults to
            the same as ``weights_dtype``.

    .. note::

        Support for FP8 in the linear layers with this plugin is currently limited to tensors with shapes where the
        dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs to conform
        to this restriction.

    """

    precision: Literal["transformer-engine", "transformer-engine-float16"] = "transformer-engine"

    def __init__(
        self,
        *,
        weights_dtype: torch.dtype,
        recipe: Optional[Union[Mapping[str, Any], "DelayedScaling"]] = None,
        replace_layers: Optional[bool] = None,
        fallback_compute_dtype: Optional[torch.dtype] = None,
    ) -> None:
        if not _TRANSFORMER_ENGINE_AVAILABLE:
            raise ModuleNotFoundError(str(_TRANSFORMER_ENGINE_AVAILABLE))
        from transformer_engine.common.recipe import DelayedScaling

        if recipe is None:
            recipe = DelayedScaling()
        elif isinstance(recipe, Mapping):
            recipe = dict(recipe)  # copy
            if "fp8_format" in recipe:
                from transformer_engine.common.recipe import Format

                recipe["fp8_format"] = getattr(Format, recipe["fp8_format"])
            recipe = DelayedScaling(**recipe)

        self.weights_dtype = weights_dtype
        self.recipe = recipe
        self.replace_layers = replace_layers
        self.fallback_compute_dtype = fallback_compute_dtype or weights_dtype
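
    # A minimal sketch (not part of the original source) of how a mapping-style ``recipe`` is handled by
    # ``__init__`` above: a string ``fp8_format`` value is resolved against
    # ``transformer_engine.common.recipe.Format`` before the ``DelayedScaling`` dataclass is built. The keyword
    # names in the example are assumptions based on the Transformer Engine documentation.
    #
    #     precision = TransformerEnginePrecision(
    #         weights_dtype=torch.bfloat16,
    #         recipe={"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"},
    #     )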

    @override
    def convert_module(self, module: torch.nn.Module) -> torch.nn.Module:
        # avoid converting if any is found. assume the user took care of it
        if any("transformer_engine.pytorch" in m.__module__ for m in module.modules()):
            if self.replace_layers is True:
                # info level because this is expected with `init_module`
                rank_zero_info(
                    "`TransformerEnginePrecision(replace_layers=True)` is set but the model already contains"
                    " TransformerEngine layers. Skipping"
                )
        elif self.replace_layers in (None, True):
            _convert_layers(module)
        module = module.to(dtype=self.weights_dtype)
        return module
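
    # Sketch (illustrative, not part of the original source): ``convert_module`` is the path for models built
    # outside of ``module_init_context``. Layer sizes here are chosen to satisfy the divisibility restriction
    # noted in the class docstring.
    #
    #     model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.LayerNorm(64))
    #     model = precision.convert_module(model)  # swaps in ``te.Linear`` / ``te.LayerNorm`` and casts weights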

    @override
    def tensor_init_context(self) -> ContextManager:
        return _DtypeContextManager(self.weights_dtype)

    @override
    def module_init_context(self) -> ContextManager:
        dtype_ctx = self.tensor_init_context()
        stack = ExitStack()
        if self.replace_layers:
            import transformer_engine.pytorch as te

            context_manager = _ClassReplacementContextManager({
                "torch.nn.Linear": te.Linear,
                "torch.nn.LayerNorm": te.LayerNorm,
            })
            stack.enter_context(context_manager)
        stack.enter_context(dtype_ctx)
        return stack
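
    # Sketch (illustrative, not part of the original source) of the effect of ``module_init_context``: while the
    # returned context is active, newly constructed ``torch.nn.Linear`` / ``torch.nn.LayerNorm`` modules resolve
    # to their Transformer Engine counterparts and parameters are created in ``weights_dtype``.
    # ``MyTransformerModel`` is a placeholder name.
    #
    #     with precision.module_init_context():
    #         model = MyTransformerModel()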

    @override
    def forward_context(self) -> ContextManager:
        dtype_ctx = _DtypeContextManager(self.weights_dtype)
        fallback_autocast_ctx = torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype)
        import transformer_engine.pytorch as te

        autocast_ctx = te.fp8_autocast(enabled=True, fp8_recipe=self.recipe)
        stack = ExitStack()
        stack.enter_context(dtype_ctx)
        # enable an outer fallback autocast for operations that do not support fp8
        stack.enter_context(fallback_autocast_ctx)
        stack.enter_context(autocast_ctx)
        return stack
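
    # The stack returned by ``forward_context`` is roughly equivalent to the following nesting (sketch; assumes a
    # CUDA device for the fallback autocast):
    #
    #     with _DtypeContextManager(self.weights_dtype), \
    #          torch.autocast(device_type="cuda", dtype=self.fallback_compute_dtype), \
    #          te.fp8_autocast(enabled=True, fp8_recipe=self.recipe):
    #         ...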

    @override
    def convert_input(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=self.weights_dtype)

    @override
    def convert_output(self, data: Any) -> Any:
        return apply_to_collection(data, function=_convert_fp_tensor, dtype=Tensor, dst_type=torch.get_default_dtype())


def _convert_layers(module: torch.nn.Module) -> None:
    import transformer_engine.pytorch as te

    for name, child in module.named_children():
        if isinstance(child, torch.nn.Linear):
            if child.in_features % 8 != 0 or child.out_features % 16 != 0:
                # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html#FP8-autocasting
                rank_zero_warn(
                    "Support for FP8 in the linear layers with this plugin is currently limited to"
                    " tensors with shapes where the dimensions are divisible by 8 and 16 respectively."
                    f" The layer {name!r} does not fit this criteria. You might want to add padding to your inputs."
                )
                continue
            has_bias = child.bias is not None
            replacement = te.Linear(child.in_features, child.out_features, bias=has_bias)
            replacement.weight.data = child.weight.data.clone()
            if has_bias:
                replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        elif isinstance(child, torch.nn.LayerNorm):
            replacement = te.LayerNorm(child.normalized_shape[0], eps=child.eps)
            replacement.weight.data = child.weight.data.clone()
            replacement.bias.data = child.bias.data.clone()
            log.debug(f"Replacing layer {name!r} with Transformer Engine equivalent")
            module.__setattr__(name, replacement)
        else:
            # there are other transformer engine layers that we could convert but require fusion. full list at:
            # https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/pytorch.html
            _convert_layers(child)
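

# A minimal usage sketch (not part of the original source). That ``Fabric(plugins=...)`` accepts a precision
# plugin instance is an assumption based on the Fabric API; adapt to your setup.
#
#     import torch
#     from lightning.fabric import Fabric
#     from lightning.fabric.plugins.precision.transformer_engine import TransformerEnginePrecision
#
#     precision = TransformerEnginePrecision(weights_dtype=torch.bfloat16, replace_layers=True)
#     fabric = Fabric(accelerator="cuda", devices=1, plugins=precision)
#     fabric.launch()
#     model, optimizer = fabric.setup(model, optimizer)  # ``model`` / ``optimizer`` defined by the user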