PrecisionPlugin
- class pytorch_lightning.plugins.precision.PrecisionPlugin
Bases: lightning_fabric.plugins.precision.precision.Precision, pytorch_lightning.core.hooks.CheckpointHooks
Base class for all plugins handling the precision-specific parts of training.
Child classes must override the class attribute precision. The default value corresponds to fp32 (32-bit floating point) training.
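A minimal sketch of the override pattern described above, using plain-Python stand-ins rather than the real Lightning classes so it runs without PyTorch. `BasePrecision`, `HalfPrecision`, and the string values `"32"` and `"16"` are hypothetical; the actual type and allowed values of `precision` depend on the Lightning version.

```python
class BasePrecision:
    """Hypothetical stand-in for PrecisionPlugin: defaults to fp32."""

    # Class attribute that child classes are expected to override.
    precision = "32"

    def describe(self) -> str:
        # Reads the attribute through self, so a subclass override is picked up.
        return f"training with precision={self.precision}"


class HalfPrecision(BasePrecision):
    """Hypothetical child plugin for 16-bit training."""

    precision = "16"


base = BasePrecision()
half = HalfPrecision()
print(base.describe())  # training with precision=32
print(half.describe())  # training with precision=16
```

With the real library, an instance of such a subclass is typically passed to the Trainer through its plugins argument.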
- backward(tensor, model, optimizer, optimizer_idx, *args, **kwargs)
Performs the actual backpropagation.
- Parameters
tensor (Tensor) – the loss value obtained from the closure
model (LightningModule) – the model to be optimized
optimizer (Optional[Steppable]) – the current optimizer being used; None if using manual optimization
optimizer_idx (Optional[int]) – the index of the current optimizer; None if using manual optimization
*args – positional arguments intended for the actual function that performs the backward, like backward()
- Return type
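A sketch of how the base backward might hand the loss off to the model, assuming (this is not stated on this page) that it simply forwards all of its arguments to the module's own backward method. `FakeModule` and `BasePrecisionSketch` are hypothetical stand-ins so the snippet runs without PyTorch, and a float plays the role of the loss tensor.

```python
from typing import Any, List, Optional, Tuple


class FakeModule:
    """Hypothetical stand-in for a LightningModule that records backward calls."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, ...]] = []

    def backward(
        self, loss: float, optimizer: Optional[object], optimizer_idx: Optional[int],
        *args: Any, **kwargs: Any,
    ) -> None:
        # A real module would run loss.backward(*args, **kwargs) here.
        self.calls.append((loss, optimizer, optimizer_idx, args, kwargs))


class BasePrecisionSketch:
    """Hypothetical stand-in for the base plugin's backward()."""

    def backward(
        self, tensor: float, model: FakeModule, optimizer: Optional[object],
        optimizer_idx: Optional[int], *args: Any, **kwargs: Any,
    ) -> None:
        # Forward everything unchanged; no precision-specific logic is added.
        model.backward(tensor, optimizer, optimizer_idx, *args, **kwargs)


module = FakeModule()
plugin = BasePrecisionSketch()
# Manual optimization: optimizer and optimizer_idx are None.
plugin.backward(0.25, module, None, None, retain_graph=True)
print(module.calls)  # [(0.25, None, None, (), {'retain_graph': True})]
```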
- clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)
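Clips the gradients of the parameters handled by optimizer, using clip_val as the threshold. By default (GradClipAlgorithmType.NORM) the gradients are clipped by their total norm; clipping element-wise by value is also available.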
- Return type
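A pure-Python sketch of the two clipping modes selected by gradient_clip_algorithm, operating on a flat list of gradient values instead of tensors. The enum mirrors the NORM default in the signature plus a VALUE member; treating clip_val <= 0 as "no clipping" and the small epsilon in the norm case follow common practice (compare `torch.nn.utils.clip_grad_norm_`) and are assumptions here, not documented behavior of this method.

```python
import math
from enum import Enum
from typing import List


class GradClipAlgorithm(Enum):
    """Hypothetical mirror of GradClipAlgorithmType."""

    NORM = "norm"
    VALUE = "value"


def clip_gradients(
    grads: List[float],
    clip_val: float = 0.0,
    algorithm: GradClipAlgorithm = GradClipAlgorithm.NORM,
) -> List[float]:
    """Return clipped copies of grads; a non-positive clip_val disables clipping (assumption)."""
    if clip_val <= 0:
        return list(grads)
    if algorithm is GradClipAlgorithm.VALUE:
        # Clamp each element independently into [-clip_val, clip_val].
        return [max(-clip_val, min(clip_val, g)) for g in grads]
    # NORM: rescale all gradients together if their global L2 norm exceeds clip_val.
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = clip_val / (total_norm + 1e-6)
    if scale >= 1.0:
        return list(grads)
    return [g * scale for g in grads]


grads = [3.0, 4.0]  # global L2 norm = 5.0
print(clip_gradients(grads, 0.0))  # [3.0, 4.0] (clipping disabled)
print(clip_gradients(grads, 3.5, GradClipAlgorithm.VALUE))  # [3.0, 3.5]
print(clip_gradients(grads, 1.0))  # approximately [0.6, 0.8]
```

Norm clipping scales every gradient by the same factor and so preserves the overall gradient direction, whereas value clipping can change it, as the [3.0, 3.5] result shows.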
- connect(model, optimizers, lr_schedulers)
Connects this plugin to the accelerator and the training process.
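A sketch of a contract that connect could follow: it receives the model, optimizers, and LR schedulers and returns them, possibly wrapped, for the rest of training to use. The pass-through base behavior and the tuple return value are assumptions; `PassThroughPrecision`, `WrappedModel`, and `WrappingPrecision` are hypothetical.

```python
from typing import Any, List, Tuple

Connected = Tuple[Any, List[Any], List[Any]]


class PassThroughPrecision:
    """Hypothetical base plugin: connect() hands everything back unchanged."""

    def connect(self, model: Any, optimizers: List[Any], lr_schedulers: List[Any]) -> Connected:
        return model, optimizers, lr_schedulers


class WrappedModel:
    """Hypothetical wrapper a precision plugin might put around the model."""

    def __init__(self, inner: Any) -> None:
        self.inner = inner


class WrappingPrecision(PassThroughPrecision):
    """Hypothetical plugin that wraps the model while connecting."""

    def connect(self, model: Any, optimizers: List[Any], lr_schedulers: List[Any]) -> Connected:
        model, optimizers, lr_schedulers = super().connect(model, optimizers, lr_schedulers)
        # For example, wrap the model so its forward runs in a reduced-precision mode.
        return WrappedModel(model), optimizers, lr_schedulers


model, optimizers, schedulers = "net", ["sgd"], []
print(PassThroughPrecision().connect(model, optimizers, schedulers))  # ('net', ['sgd'], [])
wrapped, _, _ = WrappingPrecision().connect(model, optimizers, schedulers)
print(type(wrapped).__name__, wrapped.inner)  # WrappedModel net
```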
- dispatch(trainer)
Hook to do something when Strategy.dispatch() gets called.
- Return type
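A sketch of the hook relationship: a strategy's dispatch() calls into the precision plugin's dispatch(trainer) so the plugin can run its own setup at that point. `SketchStrategy`, `LoggingPrecision`, and the string used as a trainer are hypothetical stand-ins; the real Strategy.dispatch() may do more than forward the call.

```python
from typing import Any, List


class LoggingPrecision:
    """Hypothetical plugin whose dispatch hook records that it ran."""

    def __init__(self) -> None:
        self.events: List[str] = []

    def dispatch(self, trainer: Any) -> None:
        # Precision-specific setup would go here.
        self.events.append(f"precision.dispatch({trainer})")


class SketchStrategy:
    """Hypothetical strategy that forwards its dispatch() to the precision plugin."""

    def __init__(self, precision_plugin: LoggingPrecision) -> None:
        self.precision_plugin = precision_plugin

    def dispatch(self, trainer: Any) -> None:
        # Any strategy-specific dispatch logic would also run here.
        self.precision_plugin.dispatch(trainer)


plugin = LoggingPrecision()
SketchStrategy(plugin).dispatch("trainer")
print(plugin.events)  # ['precision.dispatch(trainer)']
```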
- optimizer_step(optimizer, model, optimizer_idx, closure, **kwargs)
Hook to run the optimizer step.
- Return type
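A sketch of the closure pattern commonly used for an optimizer step: the plugin wraps the closure it receives so that extra logic (for example, gradient clipping) can run after the forward and backward pass but before the parameters are updated, then calls optimizer.step(closure=...). This wrapping is an assumption, not documented on this page; `SketchOptimizer`, `ClosurePrecision`, `_wrap_closure`, `_after_closure`, and the events list are hypothetical.

```python
from functools import partial
from typing import Any, Callable, List

events: List[str] = []


class SketchOptimizer:
    """Hypothetical optimizer that, like torch.optim optimizers, evaluates the closure in step()."""

    def step(self, closure: Callable[[], Any]) -> Any:
        loss = closure()
        events.append("update parameters")
        return loss


class ClosurePrecision:
    """Hypothetical plugin implementing optimizer_step via a wrapped closure."""

    def _wrap_closure(
        self, model: Any, optimizer: Any, optimizer_idx: int, closure: Callable[[], Any]
    ) -> Any:
        result = closure()  # forward + backward
        self._after_closure(model, optimizer, optimizer_idx)
        return result

    def _after_closure(self, model: Any, optimizer: Any, optimizer_idx: int) -> None:
        # Gradients exist at this point, so e.g. gradient clipping could happen here.
        events.append("after closure")

    def optimizer_step(
        self, optimizer: Any, model: Any, optimizer_idx: int,
        closure: Callable[[], Any], **kwargs: Any,
    ) -> Any:
        wrapped = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
        return optimizer.step(closure=wrapped, **kwargs)


def training_closure() -> float:
    events.append("forward + backward")
    return 0.5


loss = ClosurePrecision().optimizer_step(
    SketchOptimizer(), model=None, optimizer_idx=0, closure=training_closure
)
print(loss, events)  # 0.5 ['forward + backward', 'after closure', 'update parameters']
```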
- post_backward(tensor, module)
Runs after the precision plugin executes backward.
- Parameters
tensor (Tensor) – the tensor that will be used for backpropagation
module (LightningModule) – the module that was involved in producing the tensor and whose parameters need the gradients
- Return type
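A sketch of a subclass that uses post_backward to inspect gradients once they exist, here flagging non-finite values. Plain floats stand in for tensors and gradients; `Param`, `ToyModule`, `FiniteGradCheckPrecision`, and the found_inf flag are hypothetical. The sketch returns the tensor unchanged; whether the real hook transforms its return value is not modeled here.

```python
import math
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Param:
    """Hypothetical parameter with a scalar gradient."""

    grad: Optional[float] = None


@dataclass
class ToyModule:
    """Hypothetical module exposing its parameters."""

    params: List[Param] = field(default_factory=list)

    def parameters(self) -> List[Param]:
        return self.params


class FiniteGradCheckPrecision:
    """Hypothetical plugin that checks gradients right after backward."""

    def __init__(self) -> None:
        self.found_inf = False

    def post_backward(self, tensor: float, module: ToyModule) -> float:
        # Gradients are populated at this point, so they can be inspected.
        self.found_inf = any(
            p.grad is not None and not math.isfinite(p.grad) for p in module.parameters()
        )
        return tensor


plugin = FiniteGradCheckPrecision()
plugin.post_backward(0.3, ToyModule([Param(0.1), Param(float("inf"))]))
print(plugin.found_inf)  # True
```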
- pre_backward(tensor, module)
Runs before the precision plugin executes backward.
- Parameters
tensor (Tensor) – the tensor that will be used for backpropagation
module (LightningModule) – the module that was involved in producing the tensor and whose parameters need the gradients
- Return type
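A sketch of the order in which the three backward hooks run for one loss value, as described above: pre_backward, then backward, then post_backward, each receiving the same tensor and module. `OrderRecordingPrecision` and `run_backward` are hypothetical; the driver stands in for whatever code calls the plugin, and the recorded strings exist only to show the ordering.

```python
from typing import Any, List


class OrderRecordingPrecision:
    """Hypothetical plugin that records the order of its backward hooks."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def pre_backward(self, tensor: Any, module: Any) -> Any:
        self.calls.append("pre_backward")
        return tensor

    def backward(self, tensor: Any, model: Any, optimizer: Any, optimizer_idx: Any) -> None:
        self.calls.append("backward")

    def post_backward(self, tensor: Any, module: Any) -> Any:
        self.calls.append("post_backward")
        return tensor


def run_backward(plugin: OrderRecordingPrecision, loss: Any, module: Any) -> Any:
    """Hypothetical driver: pre-hook, the backward itself, then the post-hook."""
    loss = plugin.pre_backward(loss, module)
    plugin.backward(loss, module, None, None)  # manual optimization: no optimizer or index
    return plugin.post_backward(loss, module)


plugin = OrderRecordingPrecision()
result = run_backward(plugin, 1.5, module="model")
print(result, plugin.calls)  # 1.5 ['pre_backward', 'backward', 'post_backward']
```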
- predict_step_context()
A context manager for the predict step.
- test_step_context()
A context manager for the test step.
- train_step_context()
A context manager for the training step.