PrecisionPlugin
- class pytorch_lightning.plugins.precision.PrecisionPlugin
Bases: pytorch_lightning.plugins.base_plugin.Plugin, pytorch_lightning.core.hooks.CheckpointHooks
Base class for all plugins that handle the precision-specific parts of training. Child classes must override the class attribute precision; the default value corresponds to fp32 training.
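The rule that child classes override the class attribute `precision` can be illustrated without Lightning itself. The two classes below are hypothetical stand-ins, not the library's classes, and encoding precision as the integers 32 and 16 is an assumption made for the sketch.

```python
class BasePrecisionSketch:
    """Hypothetical stand-in for PrecisionPlugin; the default reflects fp32."""

    precision = 32  # assumption: fp32 encoded as the integer 32


class HalfPrecisionSketch(BasePrecisionSketch):
    """Hypothetical child class: overrides the class attribute itself."""

    precision = 16  # assumption: a half-precision child encoded as 16


# Class attributes can be read without creating an instance.
print(BasePrecisionSketch.precision)    # 32
print(HalfPrecisionSketch.precision)    # 16
# Instances see the overridden value through normal attribute lookup.
print(HalfPrecisionSketch().precision)  # 16
```

Overriding at class level (rather than setting `self.precision` in `__init__`) lets code inspect a plugin type's precision before any instance exists.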
- backward(model, closure_loss, optimizer, *args, **kwargs)
Performs the actual backpropagation of closure_loss.
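For plain fp32 training, backpropagation essentially comes down to calling `backward()` on the closure loss. The function below is a minimal plain-PyTorch sketch of that step under this assumption. `fp32_backward` is a hypothetical name, and it omits the extra `*args`/`**kwargs` the real method accepts.

```python
import torch
from torch import nn


def fp32_backward(model: nn.Module, closure_loss: torch.Tensor,
                  optimizer: torch.optim.Optimizer) -> None:
    """Hypothetical fp32 backward: no loss scaling, just autograd."""
    # model and optimizer are unused here; a mixed-precision variant might
    # need them, for example to scale the loss before backpropagating.
    closure_loss.backward()


model = nn.Linear(3, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss = model(torch.ones(2, 3)).sum()
fp32_backward(model, loss, optimizer)
# Each weight's gradient is the sum of its inputs over the batch of 2.
print(model.weight.grad)  # tensor([[2., 2., 2.]])
print(model.bias.grad)    # tensor([2.])
```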
- clip_gradients(optimizer, clip_val, gradient_clip_algorithm=GradClipAlgorithmType.NORM, model=None)
Clips the gradients according to clip_val and gradient_clip_algorithm (norm-based clipping by default).
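The choice of clipping algorithm can be sketched with PyTorch's own utilities. The default `GradClipAlgorithmType.NORM` has the value `'norm'`, as shown in the signature. Treating `'value'` as the alternative algorithm is an assumption, and `clip_gradients_sketch` is a hypothetical helper that takes parameters directly instead of an optimizer and ignores `model`.

```python
import torch
from torch import nn


def clip_gradients_sketch(parameters, clip_val: float,
                          gradient_clip_algorithm: str = "norm") -> None:
    """Hypothetical dispatch between norm- and value-based clipping."""
    parameters = list(parameters)
    if gradient_clip_algorithm == "norm":
        # Rescale all gradients together so their combined L2 norm is <= clip_val.
        nn.utils.clip_grad_norm_(parameters, max_norm=clip_val)
    elif gradient_clip_algorithm == "value":  # assumption: a 'value' option exists
        # Clamp every gradient element into [-clip_val, clip_val].
        nn.utils.clip_grad_value_(parameters, clip_value=clip_val)
    else:
        raise ValueError(f"unknown gradient_clip_algorithm: {gradient_clip_algorithm!r}")


w = nn.Parameter(torch.zeros(2))
w.grad = torch.tensor([3.0, 4.0])  # L2 norm is 5
clip_gradients_sketch([w], clip_val=1.0, gradient_clip_algorithm="norm")
print(w.grad)  # approximately tensor([0.6000, 0.8000])

v = nn.Parameter(torch.zeros(2))
v.grad = torch.tensor([3.0, -4.0])
clip_gradients_sketch([v], clip_val=1.0, gradient_clip_algorithm="value")
print(v.grad)  # tensor([ 1., -1.])
```

Norm clipping preserves the direction of the gradient vector, while value clipping can change it, which is one reason norm clipping is a common default.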
- connect(model, optimizers, lr_schedulers)
Connects this plugin to the accelerator and the training process.
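In an fp32 plugin there is nothing precision-specific to set up, so a connect step can plausibly hand its inputs back unchanged. The sketch below assumes that connect returns the `(model, optimizers, lr_schedulers)` triple, which the description above does not state. `PassThroughPrecisionSketch` is hypothetical. A mixed-precision plugin could instead wrap the model or optimizers at this point.

```python
import torch
from torch import nn


class PassThroughPrecisionSketch:
    """Hypothetical fp32 plugin: connecting leaves every object untouched."""

    def connect(self, model, optimizers, lr_schedulers):
        # Nothing to wrap or convert for plain fp32 training.
        return model, optimizers, lr_schedulers


model = nn.Linear(4, 2)
optimizers = [torch.optim.Adam(model.parameters())]
schedulers = [torch.optim.lr_scheduler.StepLR(optimizers[0], step_size=10)]
m, opts, scheds = PassThroughPrecisionSketch().connect(model, optimizers, schedulers)
print(m is model, opts is optimizers, scheds is schedulers)  # True True True
```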
- master_params(optimizer)
Returns the master parameters of the model. In this base class these are simply the plain model parameters; other precision plugins may return something different.
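A minimal sketch of the behaviour described above, assuming the plain model parameters are read from the optimizer's `param_groups` (the standard `torch.optim.Optimizer` attribute). `master_params_sketch` is a hypothetical name. In a mixed-precision plugin, the master params could instead be separate fp32 copies of the model's weights.

```python
from typing import Iterator

import torch
from torch import nn


def master_params_sketch(optimizer: torch.optim.Optimizer) -> Iterator[nn.Parameter]:
    """Hypothetical fp32 version: yield exactly the tensors the optimizer updates."""
    for group in optimizer.param_groups:
        for p in group["params"]:
            yield p


model = nn.Linear(3, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
params = list(master_params_sketch(optimizer))
# The same tensor objects as the model's parameters, not copies.
print(all(p is q for p, q in zip(params, model.parameters())))  # True
```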
- post_backward(model, closure_loss)
Runs after the precision plugin has executed the backward pass.
- Parameters
model (LightningModule) – the model to be optimized
closure_loss (Tensor) – the loss value obtained from the closure
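Once backward has run, the loss no longer needs its autograd graph. The sketch below assumes, beyond what is stated above, that a post-backward hook returns the loss detached from the graph, so that code keeping the loss around (for example for logging) does not also hold on to that graph. `post_backward_sketch` is hypothetical and ignores `model`.

```python
import torch
from torch import nn


def post_backward_sketch(model: nn.Module, closure_loss: torch.Tensor) -> torch.Tensor:
    """Hypothetical hook: return the loss cut off from the autograd graph."""
    return closure_loss.detach()


model = nn.Linear(2, 1)
loss = model(torch.ones(1, 2)).pow(2).sum()
loss.backward()
loss = post_backward_sketch(model, loss)
print(loss.requires_grad, loss.grad_fn)  # False None
```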
- post_optimizer_step(optimizer, optimizer_idx)
Hook for running custom logic after each optimizer step.
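The sketch below shows where such a hook would sit in a plain PyTorch loop, called with the optimizer and its index right after each step. `CountingPrecisionSketch` and its counter are hypothetical, and the loop is a stand-in for Lightning's training loop, not its actual code.

```python
from typing import Dict

import torch
from torch import nn


class CountingPrecisionSketch:
    """Hypothetical plugin that records how often each optimizer stepped."""

    def __init__(self) -> None:
        self.steps: Dict[int, int] = {}

    def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int) -> None:
        # Runs after the optimizer has already updated the parameters.
        self.steps[optimizer_idx] = self.steps.get(optimizer_idx, 0) + 1


model = nn.Linear(2, 1)
optimizers = [torch.optim.SGD(model.parameters(), lr=0.1)]
plugin = CountingPrecisionSketch()

for _ in range(3):
    for idx, opt in enumerate(optimizers):
        opt.zero_grad()
        model(torch.randn(4, 2)).pow(2).mean().backward()
        opt.step()
        plugin.post_optimizer_step(opt, idx)

print(plugin.steps)  # {0: 3}
```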