DeepSpeedPrecisionPlugin
- class pytorch_lightning.plugins.precision.DeepSpeedPrecisionPlugin(precision, amp_type, amp_level=None)[source]
Bases: pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin
Precision plugin for DeepSpeed integration.
- Parameters:
  - precision (Union[str, int]) – Double precision (64), full precision (32), half precision (16) or bfloat16 precision (bf16).
  - amp_type (str) – The mixed precision backend to use (“native” or “apex”).
  - amp_level (Optional[str]) – The optimization level to use (O1, O2, etc.). By default it will be set to “O2” if amp_type is set to “apex”.
- Raises:
  - MisconfigurationException – If using bfloat16 precision and deepspeed < v0.6.
  - ValueError – If an unsupported precision is provided.
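
The plugin is normally created for you when the Trainer is configured with a DeepSpeed strategy, but it can also be constructed and passed in explicitly. Below is a minimal sketch, assuming a PyTorch Lightning release in which precision plugins can be supplied through Trainer(plugins=...); the strategy name and device count are illustrative.

```python
import pytorch_lightning as pl
from pytorch_lightning.plugins.precision import DeepSpeedPrecisionPlugin

# Half precision (16) with the native AMP backend; amp_level is only
# relevant when amp_type="apex".
precision_plugin = DeepSpeedPrecisionPlugin(precision=16, amp_type="native")

trainer = pl.Trainer(
    strategy="deepspeed_stage_2",  # pairs the plugin with a DeepSpeed strategy
    plugins=[precision_plugin],
    accelerator="gpu",
    devices=2,
)
```

In most setups, Trainer(strategy="deepspeed_stage_2", precision=16) builds the same plugin without constructing it by hand.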
- backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)[source]
Performs the actual backpropagation.
- Parameters:
  - model (LightningModule) – the model to be optimized
  - closure_loss (Tensor) – the loss value obtained from the closure
  - optimizer (Optional[Optimizer]) – current optimizer being used. None if using manual optimization
- Return type:
  None
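
Since backward is a documented hook, it can be extended in a subclass. The sketch below uses a hypothetical LoggingDeepSpeedPrecision class (not part of the library) that logs the loss and then defers to the parent implementation, which hands the loss to the DeepSpeed engine.

```python
from pytorch_lightning.plugins.precision import DeepSpeedPrecisionPlugin


class LoggingDeepSpeedPrecision(DeepSpeedPrecisionPlugin):
    """Hypothetical subclass that logs each loss before backpropagation."""

    def backward(self, model, closure_loss, optimizer, optimizer_idx, *args, **kwargs):
        # closure_loss is the loss produced by the training closure; the
        # parent implementation delegates the backward pass to DeepSpeed.
        print(f"backward pass: loss={closure_loss.item():.4f}, optimizer_idx={optimizer_idx}")
        return super().backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)
```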
- clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)[source]
DeepSpeed handles gradient clipping internally.
- Return type:
  None
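
Because clip_gradients is intentionally a no-op here, the clipping threshold has to reach DeepSpeed itself. A minimal sketch follows, assuming the DeepSpeedStrategy import path of recent releases; "gradient_clipping" is DeepSpeed's own config key.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.strategies import DeepSpeedStrategy

# The threshold lives in the DeepSpeed config, not in this plugin;
# clip_gradients above deliberately does nothing.
trainer = Trainer(
    strategy=DeepSpeedStrategy(config={"gradient_clipping": 1.0}),
    precision=16,
    accelerator="gpu",
    devices=1,
)
```

Passing Trainer(gradient_clip_val=1.0) is typically translated into the same config entry by the DeepSpeed strategy, so either route works.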