
PrecisionPlugin

class pytorch_lightning.plugins.precision.PrecisionPlugin

Bases: pytorch_lightning.core.hooks.CheckpointHooks

Base class for all plugins handling the precision-specific parts of the training.

The class attribute precision must be overridden in child classes. The default value reflects fp32 training.
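Overriding precision is ordinary class-attribute shadowing. The sketch below uses hypothetical stand-in classes instead of importing Lightning. The default of 32, the Union[str, int] annotation, and the "bf16" string are assumptions chosen for illustration.

```python
from typing import Union


class BasePrecisionSketch:
    """Hypothetical stand-in for the base class (fp32 by default, assumed)."""

    precision: Union[str, int] = 32


class HalfPrecisionSketch(BasePrecisionSketch):
    """Hypothetical child class that overrides the class attribute."""

    precision = 16


class BFloat16PrecisionSketch(BasePrecisionSketch):
    # Hypothetical: a string value is also allowed by the annotation above.
    precision = "bf16"


for cls in (BasePrecisionSketch, HalfPrecisionSketch, BFloat16PrecisionSketch):
    print(cls.__name__, cls.precision)
```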

backward(model, closure_loss, optimizer, optimizer_idx, *args, **kwargs)

Performs the actual backpropagation.

Parameters:
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

  • optimizer (Optional[Optimizer]) – current optimizer being used. None if using manual optimization

Return type:

None

clip_grad_by_norm(optimizer, clip_val)

Clips the gradients of the parameters managed by optimizer so that their total norm does not exceed clip_val.

Return type:

None
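Clipping by norm multiplies every gradient by one shared factor, so the direction of the combined gradient vector is preserved and only its length shrinks. The dependency-free sketch below imitates the behavior documented for torch.nn.utils.clip_grad_norm_ (global L2 norm plus a small epsilon) on plain lists of floats. The list-of-lists layout and the function itself are illustrative assumptions, not Lightning's implementation.

```python
import math
from typing import List


def clip_grad_by_norm(grads: List[List[float]], clip_val: float, eps: float = 1e-6) -> float:
    """Scale all gradients in place so their global L2 norm is at most clip_val.

    Returns the total norm measured before clipping.
    """
    # Global norm over every element of every gradient.
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    clip_coef = clip_val / (total_norm + eps)
    # Only shrink; gradients already under the limit are left untouched.
    if clip_coef < 1.0:
        for grad in grads:
            for i in range(len(grad)):
                grad[i] *= clip_coef
    return total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # global norm = 5.0
before = clip_grad_by_norm(grads, clip_val=1.0)
print(before, grads)
```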

clip_grad_by_value(optimizer, clip_val)

Clips each gradient element of the parameters managed by optimizer to the range [-clip_val, clip_val].

Return type:

None
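Clipping by value clamps each element independently, which, unlike norm clipping, can change the direction of the gradient. A minimal sketch on plain lists of floats, modeled on torch.nn.utils.clip_grad_value_; the data layout and the function are hypothetical.

```python
from typing import List


def clip_grad_by_value(grads: List[List[float]], clip_val: float) -> None:
    """Clamp every gradient element in place to [-clip_val, clip_val]."""
    for grad in grads:
        for i, g in enumerate(grad):
            # Replacing elements (not resizing) while iterating is safe.
            grad[i] = max(-clip_val, min(clip_val, g))


grads = [[0.5, -2.0], [3.0, 0.1]]
clip_grad_by_value(grads, clip_val=1.0)
print(grads)  # [[0.5, -1.0], [1.0, 0.1]]
```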

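clip_gradients(optimizer, clip_val=0.0, gradient_clip_algorithm=GradClipAlgorithmType.NORM)

Clips the gradients of the parameters managed by optimizer, by norm (the default) or by value depending on gradient_clip_algorithm.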

Return type:

None
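clip_gradients acts as a dispatcher over the two clipping methods. In the sketch below, ClipAlgorithm is a hypothetical stand-in for GradClipAlgorithmType, the helpers are simplified list-based versions, and treating a non-positive clip_val as "no clipping" is an assumption based on the 0.0 default, not a documented guarantee.

```python
import math
from enum import Enum
from typing import List


class ClipAlgorithm(str, Enum):
    # Hypothetical stand-in for GradClipAlgorithmType.
    VALUE = "value"
    NORM = "norm"


def _clip_by_value(grads: List[List[float]], clip_val: float) -> None:
    for grad in grads:
        grad[:] = [max(-clip_val, min(clip_val, g)) for g in grad]


def _clip_by_norm(grads: List[List[float]], clip_val: float) -> None:
    total = math.sqrt(sum(g * g for grad in grads for g in grad))
    coef = clip_val / (total + 1e-6)
    if coef < 1.0:
        for grad in grads:
            grad[:] = [g * coef for g in grad]


def clip_gradients(
    grads: List[List[float]],
    clip_val: float = 0.0,
    algorithm: ClipAlgorithm = ClipAlgorithm.NORM,
) -> None:
    # Assumption: a non-positive clip_val disables clipping entirely.
    if clip_val <= 0:
        return
    if algorithm == ClipAlgorithm.VALUE:
        _clip_by_value(grads, clip_val)
    elif algorithm == ClipAlgorithm.NORM:
        _clip_by_norm(grads, clip_val)


grads = [[0.5, -2.0]]
clip_gradients(grads)  # clip_val=0.0, so nothing changes
clip_gradients(grads, 1.0, ClipAlgorithm.VALUE)
print(grads)  # [[0.5, -1.0]]
```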

connect(model, optimizers, lr_schedulers)

Connects this plugin to the accelerator and the training process.

Return type:

Tuple[Module, List[Optimizer], List[Any]]

dispatch(trainer)

Hook called when Strategy.dispatch() is called.

Return type:

None

forward_context()

A context manager for managing the model's forward pass and its training, evaluation, and prediction steps.

Return type:

Generator[None, None, None]
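Because forward_context returns Generator[None, None, None], a subclass typically overrides it with a generator decorated with contextlib.contextmanager that yields exactly once. A mixed-precision plugin might, for example, enter an autocast context around the yield. The sketch below is hypothetical and only records entry and exit, using try/finally so cleanup also runs when the step raises.

```python
import contextlib
from typing import Generator, List


class RecordingPrecisionSketch:
    """Hypothetical plugin whose forward_context logs entry and exit."""

    def __init__(self) -> None:
        self.events: List[str] = []

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        # Setup before the wrapped forward/step runs.
        self.events.append("enter")
        try:
            yield
        finally:
            # Teardown runs even if the wrapped code raises.
            self.events.append("exit")


plugin = RecordingPrecisionSketch()
with plugin.forward_context():
    plugin.events.append("forward")
print(plugin.events)  # ['enter', 'forward', 'exit']
```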

load_state_dict(state_dict)

Called when loading a checkpoint. Implement this to restore the precision plugin's state from the given state_dict.

Parameters:

state_dict (Dict[str, Any]) – the precision plugin state returned by state_dict.

Return type:

None

main_params(optimizer)

Returns the main parameters of the model.

This base implementation returns the plain model parameters; other precision plugins may return different ones.

Return type:

Iterator[Parameter]
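PyTorch optimizers keep their parameters in param_groups, a list of dicts that each hold a "params" entry. Assuming the base plugin simply walks these groups, main_params could look like the generator below. FakeOptimizer and the string placeholders for Parameter objects are hypothetical stand-ins so the example runs without PyTorch.

```python
from typing import Any, Dict, Iterator, List


class FakeOptimizer:
    """Hypothetical stand-in exposing a torch-style param_groups list."""

    def __init__(self, param_groups: List[Dict[str, Any]]) -> None:
        self.param_groups = param_groups


def main_params(optimizer: FakeOptimizer) -> Iterator[Any]:
    # Yield the plain parameters from every param group, in order.
    for group in optimizer.param_groups:
        yield from group["params"]


opt = FakeOptimizer([
    {"params": ["w1", "b1"], "lr": 0.1},
    {"params": ["w2"], "lr": 0.01},
])
print(list(main_params(opt)))  # ['w1', 'b1', 'w2']
```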

on_load_checkpoint(checkpoint)

PrecisionPlugin.on_load_checkpoint was deprecated in v1.6 and will be removed in v1.8.

Use load_state_dict instead.

Return type:

None

on_save_checkpoint(checkpoint)

PrecisionPlugin.on_save_checkpoint was deprecated in v1.6 and will be removed in v1.8.

Use state_dict instead.

Return type:

None
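Migrating off the deprecated hooks mostly means returning the plugin's own state instead of writing into the whole checkpoint dictionary. The sketch below is a hypothetical plugin; the loss_scale key and its default are illustrative, and where the trainer stores the returned dictionary inside the checkpoint is not covered here.

```python
from typing import Any, Dict


class LossScalePluginSketch:
    """Hypothetical plugin with one piece of state worth checkpointing."""

    def __init__(self, loss_scale: float = 65536.0) -> None:
        self.loss_scale = loss_scale

    # Deprecated style (v1.6, removed in v1.8): mutate the full checkpoint.
    # def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
    #     checkpoint["loss_scale"] = self.loss_scale

    # Replacement: return only this plugin's state...
    def state_dict(self) -> Dict[str, Any]:
        return {"loss_scale": self.loss_scale}

    # ...and restore it from the same dictionary.
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.loss_scale = state_dict["loss_scale"]


saved = LossScalePluginSketch(loss_scale=1024.0).state_dict()
restored = LossScalePluginSketch()
restored.load_state_dict(saved)
print(restored.loss_scale)  # 1024.0
```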

optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)

Hook to run the optimizer step.

Return type:

Any
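The closure argument follows PyTorch's optimizer.step(closure) protocol: the optimizer calls the closure, which recomputes the loss and gradients, and then updates the parameters. The sketch below mimics that protocol with a hypothetical ToySGD class so it runs without PyTorch. Having optimizer_step simply forward the closure is an assumption about the base behavior; a subclass could add logic around it.

```python
from typing import Callable, Dict


class ToySGD:
    """Hypothetical optimizer mimicking the torch step(closure) protocol."""

    def __init__(self, params: Dict[str, float], lr: float) -> None:
        self.params = params
        self.lr = lr
        self.grads: Dict[str, float] = {}

    def step(self, closure: Callable[[], float]) -> float:
        # The closure recomputes the loss and fills self.grads.
        loss = closure()
        for name, grad in self.grads.items():
            self.params[name] -= self.lr * grad
        return loss


def optimizer_step(optimizer: ToySGD, closure: Callable[[], float]) -> float:
    # Hypothetical hook body: forward the closure to the optimizer.
    return optimizer.step(closure=closure)


opt = ToySGD({"w": 3.0}, lr=0.25)


def closure() -> float:
    # loss = w**2, so d(loss)/dw = 2 * w.
    w = opt.params["w"]
    opt.grads["w"] = 2 * w
    return w * w


loss = optimizer_step(opt, closure)
print(loss, opt.params["w"])  # 9.0 1.5
```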

post_backward(model, closure_loss)

Runs after the precision plugin executes backward.

Parameters:
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

Return type:

Tensor

pre_backward(model, closure_loss)

Runs before the precision plugin executes backward.

Parameters:
  • model (LightningModule) – the model to be optimized

  • closure_loss (Tensor) – the loss value obtained from the closure

Return type:

Tensor
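pre_backward and post_backward both return a Tensor, so a plugin can replace the loss that flows through the backward pass, for example by scaling it before backpropagation. The sketch below uses floats instead of tensors, and run_backward is a hypothetical driver standing in for whatever calls these hooks during training.

```python
from typing import List, Optional


class LossScalingPluginSketch:
    """Hypothetical plugin that scales the loss and records hook order."""

    def __init__(self, scale: float) -> None:
        self.scale = scale
        self.calls: List[str] = []
        self.seen_loss: Optional[float] = None

    def pre_backward(self, model: object, closure_loss: float) -> float:
        self.calls.append("pre_backward")
        # The returned value is what backward receives.
        return closure_loss * self.scale

    def backward(self, model: object, closure_loss: float) -> None:
        self.calls.append("backward")
        # A real plugin would backpropagate; the sketch records the input.
        self.seen_loss = closure_loss

    def post_backward(self, model: object, closure_loss: float) -> float:
        self.calls.append("post_backward")
        return closure_loss


def run_backward(plugin: LossScalingPluginSketch, model: object, loss: float) -> float:
    # Hypothetical driver order: pre_backward, backward, post_backward.
    loss = plugin.pre_backward(model, loss)
    plugin.backward(model, loss)
    return plugin.post_backward(model, loss)


plugin = LossScalingPluginSketch(scale=4.0)
result = run_backward(plugin, model=None, loss=1.5)
print(plugin.calls)      # ['pre_backward', 'backward', 'post_backward']
print(plugin.seen_loss)  # 6.0
```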

predict_step_context()

A context manager for the predict step.

Return type:

Generator[None, None, None]

state_dict()

Called when saving a checkpoint. Implement this to generate the precision plugin's state_dict.

Return type:

Dict[str, Any]

Returns:

A dictionary containing precision plugin state.

teardown()

This method is called to tear down the training process.

It is the right place to release memory and free other resources.

Return type:

None

test_step_context()

A context manager for the test step.

Return type:

Generator[None, None, None]

train_step_context()

A context manager for the training step.

Return type:

Generator[None, None, None]

val_step_context()

A context manager for the validation step.

Return type:

Generator[None, None, None]
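If each step-specific context wraps forward_context, a subclass that overrides only forward_context changes the behavior of all four steps. The sketch below assumes that arrangement; the class and its logging are hypothetical.

```python
import contextlib
from typing import Generator, List


class StepContextsSketch:
    """Hypothetical plugin whose four step contexts all reuse forward_context."""

    def __init__(self) -> None:
        self.log: List[str] = []

    @contextlib.contextmanager
    def forward_context(self) -> Generator[None, None, None]:
        self.log.append("enter")
        try:
            yield
        finally:
            self.log.append("exit")

    @contextlib.contextmanager
    def train_step_context(self) -> Generator[None, None, None]:
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def val_step_context(self) -> Generator[None, None, None]:
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def test_step_context(self) -> Generator[None, None, None]:
        with self.forward_context():
            yield

    @contextlib.contextmanager
    def predict_step_context(self) -> Generator[None, None, None]:
        with self.forward_context():
            yield


plugin = StepContextsSketch()
with plugin.train_step_context():
    plugin.log.append("training_step")
with plugin.predict_step_context():
    plugin.log.append("predict_step")
print(plugin.log)
# ['enter', 'training_step', 'exit', 'enter', 'predict_step', 'exit']
```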