DeepSpeedPrecisionPlugin
- class lightning.pytorch.plugins.precision.DeepSpeedPrecisionPlugin(precision)[source]
Bases: PrecisionPlugin
Precision plugin for DeepSpeed integration.
Warning
This is an experimental feature.
- Parameters:
precision (Literal['32-true', '16-mixed', 'bf16-mixed']) – Full precision (32), half precision (16) or bfloat16 precision (bf16).
- Raises:
ValueError – If an unsupported precision is provided.
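A minimal usage sketch, assuming the plugin is handed to the Trainer together with the DeepSpeed strategy (the device count and ZeRO stage below are illustrative, not requirements of the plugin):

```python
from lightning.pytorch import Trainer
from lightning.pytorch.plugins.precision import DeepSpeedPrecisionPlugin
from lightning.pytorch.strategies import DeepSpeedStrategy

# Half precision handled by DeepSpeed; "32-true" and "bf16-mixed" work the same way.
precision_plugin = DeepSpeedPrecisionPlugin("16-mixed")

# Anything outside the supported literals raises ValueError, e.g.
# DeepSpeedPrecisionPlugin("64-true").

trainer = Trainer(
    accelerator="gpu",
    devices=2,
    strategy=DeepSpeedStrategy(stage=2),
    plugins=[precision_plugin],
)
```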
- backward(tensor, model, optimizer, *args, **kwargs)[source]
Performs back-propagation using DeepSpeed’s engine.
- Parameters:
tensor (Tensor) – the loss tensor
model (LightningModule) – the model to be optimized
optimizer (Optional[Steppable]) – ignored for DeepSpeed
*args (Any) – additional positional arguments for the deepspeed.DeepSpeedEngine.backward() call
**kwargs (Any) – additional keyword arguments for the deepspeed.DeepSpeedEngine.backward() call
- Return type:
None
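This method is invoked by Lightning internally; user code does not call it directly. A minimal sketch of the delegation it performs, assuming `engine` stands in for the wrapped deepspeed.DeepSpeedEngine (the function and parameter names below are illustrative, not the plugin's actual attributes):

```python
import torch
import deepspeed


def deepspeed_backward(engine: "deepspeed.DeepSpeedEngine", loss: torch.Tensor) -> None:
    # DeepSpeed scales the loss and accumulates gradients itself,
    # so the call is forwarded to the engine instead of calling loss.backward().
    engine.backward(loss)
```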
- clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)[source]
DeepSpeed handles gradient clipping internally.
- Return type:
None
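Because the plugin performs no clipping of its own, a clipping value has to reach DeepSpeed itself. A minimal sketch, assuming the usual route where the Trainer's gradient_clip_val is forwarded into the DeepSpeed configuration (the strategy alias and value are illustrative):

```python
from lightning.pytorch import Trainer

# The precision plugin leaves clipping to DeepSpeed; the value set here is
# applied by the DeepSpeed engine during its optimizer step.
trainer = Trainer(
    strategy="deepspeed_stage_2",
    precision="16-mixed",
    gradient_clip_val=1.0,
)
```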