PrecisionPlugin
- class pytorch_lightning.plugins.precision.PrecisionPlugin
Bases: pytorch_lightning.core.hooks.CheckpointHooks
Base class for all plugins handling the precision-specific parts of the training.
The class attribute precision must be overridden in child classes. The default value reflects fp32 training.
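As a minimal sketch of overriding the precision class attribute, the classes below are hypothetical stand-ins written in plain Python, not the library's own classes. The default of 32 is an assumption chosen to mirror the fp32 default described above.

```python
class BasePrecisionSketch:
    """Hypothetical stand-in for PrecisionPlugin; not the library class."""

    # Assumed default, mirroring fp32 training as described above.
    precision = 32


class HalfPrecisionSketch(BasePrecisionSketch):
    """Hypothetical child class that overrides the class attribute."""

    precision = 16


def describe(plugin):
    # Attribute lookup on the instance falls back to the class attribute.
    return f"{type(plugin).__name__} trains with {plugin.precision}-bit precision"


base_description = describe(BasePrecisionSketch())
half_description = describe(HalfPrecisionSketch())
```

Because precision is a class attribute, code that only holds a reference to the plugin class can read it without creating an instance.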
- backward(model, closure_loss, optimizer, *args, **kwargs)
Performs the actual backpropagation.
- Parameters
model (LightningModule) – the model to be optimized
closure_loss (Tensor) – the loss value obtained from the closure
optimizer (Optional[Optimizer]) – current optimizer being used. None if using manual optimization
- Return type
- clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
Clips the gradients.
- Return type
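The default gradient_clip_algorithm is norm-based clipping. The sketch below shows the general idea on a flat list of floats in plain Python; it is not the plugin's implementation. The helper name clip_by_total_norm is hypothetical, and treating a clip_val of 0 or less as "clipping disabled" is an assumption.

```python
import math


def clip_by_total_norm(grads, clip_val):
    """Rescale gradient values so their L2 norm is at most clip_val.

    Returns the (possibly rescaled) gradients and the norm measured
    before clipping. Hypothetical helper for illustration only.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    # Assumption: a non-positive clip_val means clipping is disabled.
    if clip_val <= 0 or total_norm <= clip_val:
        return list(grads), total_norm
    scale = clip_val / total_norm
    return [g * scale for g in grads], total_norm


# The gradient [3, 4] has norm 5, so clipping to 1.0 scales it by 0.2.
clipped, norm_before = clip_by_total_norm([3.0, 4.0], clip_val=1.0)
unchanged, _ = clip_by_total_norm([3.0, 4.0], clip_val=0.0)
```

Norm clipping preserves the direction of the gradient and only shrinks its length, in contrast to clipping each value independently.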
- connect(model, optimizers, lr_schedulers)
Connects this plugin to the accelerator and the training process.
- dispatch(trainer)
Hook to do something when Accelerator.dispatch() gets called.
- Return type
- forward_context()
A context manager for managing model forward/training_step/evaluation_step/predict_step.
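The context-manager pattern can be sketched with contextlib from the standard library. ContextSketch is a hypothetical class, and the active flag and events list exist only for illustration; a real precision plugin would set up and tear down precision-specific state (for example, an autocast region) at the same two points.

```python
from contextlib import contextmanager


class ContextSketch:
    """Hypothetical plugin-like class; only the pattern matters here."""

    def __init__(self):
        self.active = False
        self.events = []

    @contextmanager
    def forward_context(self):
        # Set up state before the wrapped step runs.
        self.active = True
        self.events.append("enter")
        try:
            yield
        finally:
            # Always restore state, even if the wrapped step raises.
            self.active = False
            self.events.append("exit")


plugin = ContextSketch()
with plugin.forward_context():
    was_active_inside = plugin.active
```

The try/finally around the yield ensures the state is reset even when the step inside the with block fails.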
- main_params(optimizer)
The main params of the model.
Returns the plain model params here. May be different in other precision plugins.
- master_params(optimizer)
The main params of the model.
Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Use main_params() instead.
- Return type
Iterator[Parameter]
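One common way to keep a deprecated name working is to have it warn and forward to its replacement. The sketch below illustrates that pattern with a hypothetical ParamsSketch class and a toy optimizer represented as a dict. Collecting parameters by iterating over param_groups is an assumption about how the "plain model params" might be gathered, not a statement of the library's code.

```python
import warnings


class ParamsSketch:
    """Hypothetical class showing a deprecated alias for main_params."""

    def main_params(self, optimizer):
        # Assumption: parameters are collected from the optimizer's groups.
        for group in optimizer["param_groups"]:
            yield from group["params"]

    def master_params(self, optimizer):
        # Warn at call time, then delegate to the replacement method.
        warnings.warn(
            "master_params is deprecated; use main_params instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.main_params(optimizer)


# Toy stand-in for an optimizer: strings take the place of parameters.
toy_optimizer = {"param_groups": [{"params": ["w1", "b1"]}, {"params": ["w2"]}]}

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    params = list(ParamsSketch().master_params(toy_optimizer))
```

Callers that switch to main_params() get the same parameters without the warning.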
- optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
Hook to run the optimizer step.
- Return type
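The closure argument suggests the usual closure-based optimizer step, where the optimizer itself invokes a callable that computes the loss and gradients. The following is a sketch of that pattern with a hypothetical ToySGD class operating on a single scalar. Forwarding the closure straight to optimizer.step is an assumption about what this hook might do, not the library's confirmed behavior.

```python
class ToySGD:
    """Hypothetical optimizer: one scalar parameter, fixed learning rate."""

    def __init__(self, param, lr):
        self.param = param  # dict with "value" and "grad" entries
        self.lr = lr

    def step(self, closure):
        # The closure computes the loss and fills in the gradient.
        loss = closure()
        self.param["value"] -= self.lr * self.param["grad"]
        return loss


def optimizer_step_sketch(optimizer, closure):
    # Assumed behavior: hand the closure to the optimizer's step.
    return optimizer.step(closure)


param = {"value": 3.0, "grad": 0.0}


def closure():
    # Loss is value**2, so the gradient is 2 * value.
    param["grad"] = 2.0 * param["value"]
    return param["value"] ** 2


returned_loss = optimizer_step_sketch(ToySGD(param, lr=0.1), closure)
```

Here the loss is evaluated at the starting value 3.0, and the parameter then moves by lr times the gradient, from 3.0 to about 2.4.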
- post_backward(model, closure_loss)
Run after precision plugin executes backward.
- Parameters
model (LightningModule) – the model to be optimized
closure_loss (Tensor) – the loss value obtained from the closure
- Return type
- post_dispatch()
Hook to do something after the training/evaluation/prediction finishes.
- Return type
- pre_backward(model, closure_loss)
Run before precision plugin executes backward.
- Parameters
model (LightningModule) – the model to be optimized
closure_loss (Tensor) – the loss value obtained from the closure
- Return type
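pre_backward and post_backward bracket the backward call. The sketch below records that ordering with a hypothetical BackwardOrderSketch class and run_backward driver, neither of which is part of the library. Returning the loss from each hook is an assumption made for illustration.

```python
class BackwardOrderSketch:
    """Hypothetical recorder; it only logs the order of the hook calls."""

    def __init__(self):
        self.calls = []

    def pre_backward(self, model, closure_loss):
        self.calls.append("pre_backward")
        return closure_loss  # assumption: the hook hands the loss back

    def backward(self, model, closure_loss, optimizer):
        # A real plugin would run backpropagation on the loss here.
        self.calls.append("backward")

    def post_backward(self, model, closure_loss):
        self.calls.append("post_backward")
        return closure_loss  # assumption: the hook hands the loss back


def run_backward(plugin, model, closure_loss, optimizer=None):
    # Hypothetical driver: run the three hooks in the documented order.
    closure_loss = plugin.pre_backward(model, closure_loss)
    plugin.backward(model, closure_loss, optimizer)
    return plugin.post_backward(model, closure_loss)


recorder = BackwardOrderSketch()
result = run_backward(recorder, model=None, closure_loss=0.5)
```

A precision plugin that scales the loss would typically scale it in the first hook and undo the scaling in the last, which is why both hooks receive closure_loss.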
- pre_dispatch()
Hook to do something before the training/evaluation/prediction starts.
- Return type
- predict_step_context()
A context manager for the predict step.
- test_step_context()
A context manager for the test step.
- train_step_context()
A context manager for the training step.