TransformerEnginePrecision
- class lightning.fabric.plugins.precision.TransformerEnginePrecision(dtype=None, recipe=None, replace_layers=None)
Bases: Precision
Plugin for training with FP8 precision via NVIDIA's Transformer Engine.
Warning
This is an experimental feature.
- Parameters:
recipe (Union[Mapping[str, Any], DelayedScaling, None]) – Recipe for the DelayedScaling configuration, in dict format or the dataclass format.
replace_layers (Optional[bool]) – Whether to replace Linear and LayerNorm layers automatically with their Transformer Engine alternatives. Note that they don't subclass the torch equivalents, so checks like isinstance(l, torch.nn.Linear) will not pass.
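As a rough sketch of how the parameters above might be used together, the snippet below builds a dict recipe and passes it to the plugin. The recipe keys are assumed to match fields of Transformer Engine's DelayedScaling, the values are illustrative only, and build_fabric is a hypothetical helper that returns None when the optional dependencies or a CUDA device are missing.

```python
import importlib.util

# Dict-format recipe. The keys are assumed to match DelayedScaling fields;
# the values are illustrative, not recommendations.
recipe = {
    "fp8_format": "HYBRID",
    "amax_history_len": 16,
    "amax_compute_algo": "max",
}


def build_fabric(recipe):
    """Hypothetical helper: return a Fabric using the FP8 plugin, or None."""
    needed = ("torch", "lightning", "transformer_engine")
    if any(importlib.util.find_spec(name) is None for name in needed):
        return None  # optional dependencies are not installed

    import torch
    from lightning.fabric import Fabric
    from lightning.fabric.plugins.precision import TransformerEnginePrecision

    if not torch.cuda.is_available():
        return None  # FP8 training needs a supported NVIDIA GPU

    # Signature as documented on this page: dtype, recipe, replace_layers.
    plugin = TransformerEnginePrecision(
        dtype=torch.bfloat16, recipe=recipe, replace_layers=True
    )
    return Fabric(accelerator="cuda", devices=1, plugins=plugin)
```

Because the replaced layers do not subclass the torch equivalents, any code that selects layers with isinstance(l, torch.nn.Linear) would need to be adjusted after the model is set up with this plugin.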
Note
Support for FP8 in the linear layers with this plugin is currently limited to tensors with shapes where the dimensions are divisible by 8 and 16 respectively. You might want to add padding to your inputs to conform to this restriction.
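The padding mentioned in the note can be computed with simple arithmetic. The sketch below assumes a 2D input of shape (rows, cols) where the first dimension must be a multiple of 8 and the last a multiple of 16; that reading of "8 and 16 respectively" is an assumption, and round_up and fp8_padding are hypothetical helper names.

```python
def round_up(size: int, multiple: int) -> int:
    """Return the smallest multiple of `multiple` that is >= size."""
    return -(-size // multiple) * multiple


def fp8_padding(shape, row_multiple=8, col_multiple=16):
    """Return (pad_rows, pad_cols) needed to make a 2D shape conform.

    Assumption: rows must be divisible by 8 and columns by 16.
    """
    rows, cols = shape
    pad_rows = round_up(rows, row_multiple) - rows
    pad_cols = round_up(cols, col_multiple) - cols
    return pad_rows, pad_cols


pad_rows, pad_cols = fp8_padding((30, 100))
print(pad_rows, pad_cols)  # 2 12, i.e. pad the shape to (32, 112)
```

With a torch tensor x of that shape, the padding could then be applied with torch.nn.functional.pad(x, (0, pad_cols, 0, pad_rows)), which pads the last dimension first.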
- convert_input(data)
Convert model inputs (forward) to the floating point precision type of this plugin.
This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32).
- Return type:
- convert_module(module)
Convert the module parameters to the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
- Return type:
- convert_output(data)
Convert outputs to the floating point precision type expected after the model's forward pass.
This is a no-op in the base precision plugin, since we assume the data already has the desired type (default is torch.float32).
- Return type:
- forward_context()
A context manager for managing the model's forward, training_step, evaluation_step, and predict_step calls.
- Return type:
- module_init_context()
Instantiate module parameters or tensors in the precision type this plugin handles.
This is optional and depends on the precision limitations during optimization.
- Return type: