TransformerEnginePrecision

class lightning.fabric.plugins.precision.TransformerEnginePrecision(dtype=None, recipe=None, replace_layers=None)

Bases: Precision

Plugin for training with FP8 precision via NVIDIA’s Transformer Engine.

Warning

This is an experimental feature.

Parameters:
  • dtype (Optional[dtype]) – The weights dtype to use.

  • recipe (Union[Mapping[str, Any], DelayedScaling, None]) – Recipe for the DelayedScaling configuration, given either as a dict or as a DelayedScaling dataclass instance.

  • replace_layers (Optional[bool]) – Whether to automatically replace Linear and LayerNorm layers with their Transformer Engine alternatives. Note that the replacements don’t subclass the torch equivalents, so checks like isinstance(l, torch.nn.Linear) will not pass.
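
A minimal sketch of how these parameters might be passed, assuming a CUDA machine with Transformer Engine installed. The helper name build_fabric is hypothetical, and the recipe keys shown (amax_history_len, amax_compute_algo) are assumed to match fields of Transformer Engine's DelayedScaling. The imports are deferred into the function so the snippet can be loaded on a machine without a GPU.

```python
# Recipe in dict form; the keys are assumed to mirror DelayedScaling fields.
RECIPE = {"amax_history_len": 16, "amax_compute_algo": "max"}


def build_fabric():
    """Hypothetical helper: create a Fabric object that trains with FP8."""
    # Deferred imports: these need torch, lightning and transformer_engine.
    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins.precision import TransformerEnginePrecision

    precision = TransformerEnginePrecision(
        dtype=torch.bfloat16,  # dtype used for the weights
        recipe=RECIPE,         # dict form of the DelayedScaling recipe
        replace_layers=True,   # swap Linear/LayerNorm for Transformer Engine layers
    )
    # Pass the plugin to Fabric through the `plugins` argument.
    return Fabric(accelerator="cuda", devices=1, plugins=precision)
```

Because the replaced layers no longer pass isinstance checks against torch.nn.Linear, any code that inspects layer types after setup would need to account for the Transformer Engine classes.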

Note

Support for FP8 in the linear layers with this plugin is currently limited to tensors with shapes where the dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs to conform to this restriction.
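
The padding suggested above comes down to rounding each relevant dimension up to the required multiple. The following is a small sketch of that arithmetic only. The function name round_up is hypothetical, and because the note does not spell out which dimension needs which multiple, the multiple is left as an argument.

```python
def round_up(size: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= `size`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    return ((size + multiple - 1) // multiple) * multiple


# Example: a dimension of 50 padded to a multiple of 16 becomes 64,
# so 14 padding elements would be needed along that dimension.
extra = round_up(50, 16) - 50
```

The resulting padding amount could then be applied with a utility such as torch.nn.functional.pad before the tensor reaches the FP8 linear layers.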

convert_input(data)

Convert model inputs (forward) to the floating point precision type of this plugin.

This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32).

Return type:

Any

convert_module(module)

Convert the module parameters to the precision type this plugin handles.

This is optional and depends on the precision limitations during optimization.

Return type:

Module

convert_output(data)

Convert outputs to the floating point precision type expected after the model’s forward pass.

This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32).

Return type:

Any

forward_context()

A context manager that wraps the model’s forward, training_step, evaluation_step, and predict_step calls.

Return type:

ContextManager
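
As an illustration of how this context manager is meant to be used, here is a minimal sketch. The function name run_forward is hypothetical; precision is assumed to be an instance of this plugin and model any callable. Fabric typically enters this context for you when a module set up through Fabric is called, so a manual wrapper like this would mainly be useful when driving the plugin directly.

```python
def run_forward(precision, model, batch):
    """Call `model` on `batch` inside the plugin's forward context."""
    # Everything executed inside the `with` block runs under the
    # precision settings that the plugin configures.
    with precision.forward_context():
        return model(batch)
```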

module_init_context()

Instantiate module parameters or tensors in the precision type this plugin handles.

This is optional and depends on the precision limitations during optimization.

Return type:

ContextManager

tensor_init_context()

Controls how tensors get created (device, dtype).

Return type:

ContextManager
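
The two initialization contexts above can be sketched in use as follows. This assumes precision is an instance of the plugin; the function name init_model_and_buffer and both factory arguments are hypothetical. In normal Fabric usage, the fabric.init_module() and fabric.init_tensor() context managers are the usual entry points rather than calling the plugin directly.

```python
def init_model_and_buffer(precision, model_factory, tensor_factory):
    """Create a model and a tensor under the plugin's init contexts."""
    # Parameters created here are instantiated in the precision type
    # the plugin handles.
    with precision.module_init_context():
        model = model_factory()
    # Tensors created here pick up the plugin's creation settings.
    with precision.tensor_init_context():
        buffer = tensor_factory()
    return model, buffer
```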