PrecisionPlugin

class pytorch_lightning.plugins.precision.PrecisionPlugin

Bases: pytorch_lightning.core.hooks.CheckpointHooks

Base class for all plugins handling the precision-specific parts of the training.

The class attribute precision must be overridden in child classes. The default value reflects fp32 training.
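As a minimal sketch of what overriding the class attribute looks like, the example below uses dependency-free stand-in classes. Both class names are hypothetical and are not part of Lightning, and the value 32 for the default is an assumption based on the statement that the default reflects fp32 training.

```python
class BasePrecisionSketch:
    """Hypothetical stand-in for the base class; not importable from Lightning."""

    precision = 32  # assumed default, standing for fp32 training


class HalfPrecisionSketch(BasePrecisionSketch):
    """Hypothetical child class that overrides the class attribute."""

    precision = 16


plugins = [BasePrecisionSketch(), HalfPrecisionSketch()]
reported = [plugin.precision for plugin in plugins]
```

Because precision is a class attribute, code holding a plugin instance can read it without knowing which concrete subclass it was given.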

backward(model, closure_loss, optimizer, *args, **kwargs)

Performs the actual backpropagation.

Parameters
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

  • optimizer (Optional[Optimizer]) – current optimizer being used. None if using manual optimization

Return type

None

clip_grad_by_norm(optimizer, clip_val)

Clip gradients by norm.

Return type

None

clip_grad_by_value(optimizer, clip_val)

Clip gradients by value.

Return type

None

clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)

Clips the gradients.

Return type

None
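The two clipping algorithms bound different things: clipping by value limits each gradient entry, while clipping by norm limits the length of the whole gradient vector. The sketch below illustrates the difference on plain Python lists. The helper names clip_by_value and clip_by_norm are hypothetical and are not Lightning functions; the plugin itself works on the parameters held by the optimizer, not on lists.

```python
import math
from typing import List


def clip_by_value(grads: List[float], clip_val: float) -> List[float]:
    """Clamp every gradient entry into [-clip_val, clip_val]."""
    return [max(-clip_val, min(clip_val, g)) for g in grads]


def clip_by_norm(grads: List[float], clip_val: float) -> List[float]:
    """Rescale the whole gradient vector so its L2 norm is at most clip_val."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= clip_val:
        return list(grads)  # already within the bound, nothing to do
    scale = clip_val / total_norm
    return [g * scale for g in grads]


grads = [3.0, -4.0]  # L2 norm is 5.0
by_value = clip_by_value(grads, 1.0)
by_norm = clip_by_norm(grads, 1.0)
```

Clipping by norm keeps the direction of the gradient and only shortens it, whereas clipping by value can change the direction because each entry is clamped independently.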

connect(model, optimizers, lr_schedulers)

Connects this plugin to the accelerator and the training process.

Return type

Tuple[Module, List[Optimizer], List[Any]]

dispatch(trainer)

Hook to do something when Accelerator.dispatch() gets called.

Return type

None

forward_context()

A contextmanager for managing model forward/training_step/evaluation_step/predict_step.

Return type

Generator[None, None, None]
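The return type Generator[None, None, None] is the usual signature of a function decorated with contextlib.contextmanager: it yields once, with no value, while the wrapped step runs. The sketch below shows that pattern with a stand-in class. RecordingPrecision and its events list are hypothetical names for illustration only, and the comments about switching precision describe what a subclass might do, not what the base class does.

```python
from contextlib import contextmanager
from typing import Generator, List


class RecordingPrecision:
    """Hypothetical stand-in for a precision plugin; not a Lightning class."""

    def __init__(self) -> None:
        self.events: List[str] = []

    @contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        self.events.append("enter")      # e.g. a subclass might enable reduced precision
        try:
            yield                        # the wrapped step runs here
        finally:
            self.events.append("exit")   # runs even if the step raises


plugin = RecordingPrecision()
with plugin.forward_context():
    plugin.events.append("forward")
```

Placing the cleanup in a finally block means the previous state is restored even when the wrapped step raises an exception.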

main_params(optimizer)

The main params of the model.

Returns the plain model params here. May be different in other precision plugins.

Return type

Iterator[Parameter]

master_params(optimizer)

The main params of the model.

Deprecated since version v1.5: This method is deprecated in v1.5 and will be removed in v1.6. Use main_params() instead.

Return type

Iterator[Parameter]
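A deprecated alias of this kind is commonly written as a thin wrapper that warns and then forwards to the replacement. The sketch below shows that general pattern only; PluginSketch is a hypothetical stand-in, it is not the Lightning implementation, and for brevity its methods omit the optimizer argument that the real methods take.

```python
import warnings
from typing import Iterator, List


class PluginSketch:
    """Hypothetical stand-in; not the Lightning implementation."""

    def __init__(self, params: List[float]) -> None:
        self._params = params

    def main_params(self) -> Iterator[float]:
        """The replacement method that callers should use."""
        return iter(self._params)

    def master_params(self) -> Iterator[float]:
        """Deprecated alias: warn, then forward to main_params()."""
        warnings.warn(
            "master_params() is deprecated; use main_params() instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self.main_params()


plugin = PluginSketch([0.1, 0.2])
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")  # make sure the warning is not filtered out
    old_result = list(plugin.master_params())
new_result = list(plugin.main_params())
```

For callers, migrating means replacing the old method name with main_params(); in this sketch both calls return the same parameters.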

optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)

Hook to run the optimizer step.

Return type

None

post_backward(model, closure_loss)

Run after precision plugin executes backward.

Parameters
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

Return type

Tensor

post_dispatch()

Hook to do something after the training/evaluation/prediction finishes.

Return type

None

pre_backward(model, closure_loss)

Run before precision plugin executes backward.

Parameters
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

Return type

Tensor
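Taken together, the descriptions of pre_backward, backward, and post_backward imply a fixed call order around backpropagation, with the loss passed from one hook to the next. The sketch below only illustrates that order. HookOrderSketch and run_backward are hypothetical names, the loss is a plain float instead of a Tensor, the model argument is omitted, and the doubling of the loss in pre_backward is an invented example of a hook modifying the value it returns.

```python
from typing import List


class HookOrderSketch:
    """Hypothetical stand-in that records the order of the backward hooks."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def pre_backward(self, closure_loss: float) -> float:
        self.calls.append("pre_backward")
        return closure_loss * 2.0  # assumption: a hook may return a modified loss

    def backward(self, closure_loss: float) -> None:
        self.calls.append("backward")  # backpropagation would happen here

    def post_backward(self, closure_loss: float) -> float:
        self.calls.append("post_backward")
        return closure_loss


def run_backward(plugin: HookOrderSketch, loss: float) -> float:
    """Call the three hooks in the order the descriptions imply."""
    loss = plugin.pre_backward(loss)
    plugin.backward(loss)
    return plugin.post_backward(loss)


plugin = HookOrderSketch()
result = run_backward(plugin, 0.5)
```

Returning the loss from pre_backward and post_backward lets a subclass hand a transformed value to the next stage without the caller knowing which plugin is active.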

pre_dispatch()

Hook to do something before the training/evaluation/prediction starts.

Return type

None

predict_step_context()

A contextmanager for the predict step.

Return type

Generator[None, None, None]

test_step_context()

A contextmanager for the test step.

Return type

Generator[None, None, None]

train_step_context()

A contextmanager for the training step.

Return type

Generator[None, None, None]

val_step_context()

A contextmanager for the validation step.

Return type

Generator[None, None, None]