TransformerEnginePrecisionPlugin
- class lightning.pytorch.plugins.precision.TransformerEnginePrecisionPlugin(dtype=None, recipe=None, replace_layers=None)
Bases: PrecisionPlugin, TransformerEnginePrecision
Plugin for training with FP8 precision via NVIDIA's Transformer Engine.
Warning
This is an experimental feature.
- Parameters:
  - recipe (Union[Mapping[str, Any], DelayedScaling, None]) – Recipe for the DelayedScaling configuration, in dict format or the dataclass format.
  - replace_layers (Optional[bool]) – Whether to replace Linear and LayerNorm layers automatically with their Transformer Engine alternatives. Note that they don't subclass the torch equivalents, so checks like isinstance(l, torch.nn.Linear) will not pass.
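A minimal usage sketch, assuming Transformer Engine is installed and a CUDA device is available; the recipe keys shown mirror transformer_engine.common.recipe.DelayedScaling and are illustrative, not required values:

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision import TransformerEnginePrecisionPlugin

# Recipe given in dict format; string values such as "HYBRID" are resolved
# against transformer_engine.common.recipe.Format by the plugin.
recipe = {"fp8_format": "HYBRID", "amax_history_len": 16, "amax_compute_algo": "max"}

precision = TransformerEnginePrecisionPlugin(recipe=recipe, replace_layers=True)
trainer = Trainer(accelerator="cuda", devices=1, plugins=[precision])
```

Passing recipe=None falls back to Transformer Engine's default DelayedScaling configuration.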
Note
Support for FP8 in the linear layers of this plugin is currently limited to tensors with shapes where the dimensions are divisible by 8 and 16, respectively. You might want to add padding to your inputs to conform to this restriction, as in the sketch below.
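As an illustration, a hypothetical helper that zero-pads the last dimension of an input batch to the next multiple of 16; the helper name and the choice of which dimension to pad are assumptions for this example, not part of the plugin's API:

```python
import torch
import torch.nn.functional as F


def pad_last_dim(x: torch.Tensor, multiple: int = 16) -> torch.Tensor:
    # Hypothetical helper: zero-pad the last dimension up to the next
    # multiple so the FP8 GEMM shape constraint is satisfied.
    pad = (-x.shape[-1]) % multiple
    return F.pad(x, (0, pad)) if pad else x


x = torch.randn(32, 500)  # 500 is not divisible by 16
x = pad_last_dim(x)       # shape becomes (32, 512)
```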