DeepSpeedPrecisionPlugin
- class pytorch_lightning.plugins.precision.DeepSpeedPrecisionPlugin(precision, amp_type=None, amp_level=None)
Bases: pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin
Precision plugin for DeepSpeed integration.
- Parameters
precision (Literal['32', 32, '16', 16, 'bf16']) – Full precision (32), half precision (16) or bfloat16 precision (bf16).
- Raises
ValueError – If an unsupported precision is provided.
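The Raises entry above amounts to a membership check against the values listed in the documented Literal type, where both the string and integer spellings of 32 and 16 are accepted. The following is a minimal sketch of that contract, assuming the plugin simply rejects anything outside the documented set. `validate_deepspeed_precision` and `SUPPORTED_PRECISIONS` are hypothetical names used for illustration only; they are not part of the Lightning API, and this is not the plugin's actual source.

```python
from typing import Union

# Accepted values, copied from the documented Literal type of `precision`.
# Hypothetical constant: not exported by pytorch_lightning.
SUPPORTED_PRECISIONS = ("32", 32, "16", 16, "bf16")


def validate_deepspeed_precision(precision: Union[str, int]) -> Union[str, int]:
    """Return `precision` unchanged if it is supported, else raise ValueError.

    Hypothetical helper that mirrors the documented contract of
    DeepSpeedPrecisionPlugin(precision=...); it is a sketch, not library code.
    """
    if precision not in SUPPORTED_PRECISIONS:
        raise ValueError(
            f"precision={precision!r} is not supported; "
            f"expected one of {SUPPORTED_PRECISIONS}"
        )
    return precision


if __name__ == "__main__":
    # Both spellings of 16 are accepted; 64 falls outside the documented set.
    for value in (32, "16", "bf16", 64):
        try:
            validate_deepspeed_precision(value)
            print(f"{value!r}: accepted")
        except ValueError as err:
            print(f"{value!r}: rejected ({err})")
```

Failing fast at construction time, rather than when the first backward pass runs, is the design choice this check illustrates: a typo such as "fp16" surfaces immediately as a ValueError.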
- backward(tensor, model, optimizer, optimizer_idx, *args, **kwargs)
Performs back-propagation using DeepSpeed’s engine.
- Parameters
model (LightningModule) – the model to be optimized
*args – additional positional arguments for the deepspeed.DeepSpeedEngine.backward() call
**kwargs – additional keyword arguments for the deepspeed.DeepSpeedEngine.backward() call
- Return type