ApexMixedPrecisionPlugin
- class pytorch_lightning.plugins.precision.ApexMixedPrecisionPlugin(amp_level='O2')
Bases: pytorch_lightning.plugins.precision.mixed.MixedPrecisionPlugin
Mixed precision plugin based on NVIDIA Apex (https://github.com/NVIDIA/apex).
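The `amp_level` argument selects one of Apex's optimization levels, `O0` through `O3` (the default here is `O2`). The sketch below is a hypothetical helper, not part of PyTorch Lightning, that records what each level means and rejects typos such as `"02"` (zero instead of the letter O). The helper name and the one-line descriptions are assumptions summarizing the Apex documentation.

```python
# Hypothetical summary of Apex optimization levels (note: letter "O", not zero).
APEX_OPT_LEVELS = {
    "O0": "pure FP32 training (baseline)",
    "O1": "mixed precision: Torch functions are patched to cast inputs to FP16 or FP32",
    "O2": "'almost FP16': FP16 model weights, FP32 master weights, batchnorm kept in FP32",
    "O3": "pure FP16 training",
}


def validate_amp_level(amp_level: str = "O2") -> str:
    """Return amp_level unchanged if it names a known Apex opt level, else raise ValueError."""
    if amp_level not in APEX_OPT_LEVELS:
        raise ValueError(
            f"unknown amp_level {amp_level!r}; expected one of {sorted(APEX_OPT_LEVELS)}"
        )
    return amp_level
```

Validating the string up front turns a silent misconfiguration into an immediate, readable error.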
- backward(model, closure_loss, optimizer, *args, **kwargs)
Runs before the precision plugin executes the backward pass.
- Parameters
model (LightningModule) – the model to be optimized
closure_loss (Tensor) – the loss value obtained from the closure
optimizer (Optional[Optimizer]) – the current optimizer being used; None if using manual optimization
- Return type
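Mixed-precision backward passes usually scale the loss so that small FP16 gradients do not underflow to zero; Apex provides this through its `amp.scale_loss` context manager. The following is a pure-Python sketch of dynamic loss scaling, not Apex's implementation: the class name `DynamicLossScaler`, the initial scale, and the growth interval are illustrative assumptions, and gradients are plain floats rather than tensors.

```python
import math


class DynamicLossScaler:
    """Hypothetical loss scaler illustrating the idea behind dynamic loss scaling."""

    def __init__(self, init_scale: float = 2.0 ** 16, growth_interval: int = 2000):
        self.scale = init_scale
        self.growth_interval = growth_interval
        self._good_steps = 0

    def scale_loss(self, loss: float) -> float:
        # Multiply the loss so the gradients computed from it are scaled up too.
        return loss * self.scale

    def unscale_and_check(self, grads):
        """Divide gradients by the scale; return (unscaled_grads, ok). ok is False on inf/nan."""
        unscaled = [g / self.scale for g in grads]
        ok = all(math.isfinite(g) for g in unscaled)
        self._update(ok)
        return unscaled, ok

    def _update(self, ok: bool) -> None:
        if not ok:
            # Overflow: the caller should skip this step; halve the scale.
            self.scale /= 2.0
            self._good_steps = 0
            return
        self._good_steps += 1
        if self._good_steps >= self.growth_interval:
            # Enough consecutive finite steps: try a larger scale.
            self.scale *= 2.0
            self._good_steps = 0
```

Halving on overflow and growing only after a run of clean steps keeps the scale near the largest value that does not overflow.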
- dispatch(trainer)
Hook to do something when Strategy.dispatch() gets called.
- Return type
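The hook pattern described above can be sketched with two toy classes: a strategy whose own dispatch step forwards to its precision plugin's hook. Both `ToyStrategy` and `ToyPrecisionPlugin` are hypothetical stand-ins, not Lightning classes, and the trainer is any placeholder object.

```python
class ToyPrecisionPlugin:
    """Hypothetical plugin whose dispatch() hook just records the trainer it saw."""

    def __init__(self):
        self.dispatched_with = None

    def dispatch(self, trainer) -> None:
        # A real plugin could run setup here; this toy only records the call.
        self.dispatched_with = trainer


class ToyStrategy:
    """Hypothetical strategy that fires the plugin hook when it dispatches."""

    def __init__(self, precision_plugin: ToyPrecisionPlugin):
        self.precision_plugin = precision_plugin

    def dispatch(self, trainer) -> None:
        # Strategy-specific work would happen here, then the plugin hook runs.
        self.precision_plugin.dispatch(trainer)
```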
- load_state_dict(state_dict)
Called when loading a checkpoint; implement this to restore the precision plugin's state from the given precision plugin state_dict.
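A minimal sketch of the save/restore round trip this hook participates in, assuming a plugin whose only state is a loss scale. The class `ToyScalerPlugin` and the `"loss_scale"` key are hypothetical; the real plugin's state layout is not described here.

```python
class ToyScalerPlugin:
    """Hypothetical precision plugin whose only state is a loss scale."""

    def __init__(self, loss_scale: float = 65536.0):
        self.loss_scale = loss_scale

    def state_dict(self) -> dict:
        # Saved into the checkpoint alongside model and optimizer state.
        return {"loss_scale": self.loss_scale}

    def load_state_dict(self, state_dict: dict) -> None:
        # Called when resuming: restore exactly what state_dict() produced.
        self.loss_scale = state_dict["loss_scale"]
```

Keeping `load_state_dict` the exact inverse of `state_dict` means a resumed run continues with the same scale instead of restarting from the default.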
- main_params(optimizer)
The main parameters of the model.
Returns the plain model parameters here. This may differ in other precision plugins.
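Returning the plain model parameters can be sketched as flattening the optimizer's parameter groups. This assumes the optimizer exposes `param_groups` as a list of dicts with a `"params"` entry, the layout `torch.optim.Optimizer` uses; the standalone function below is an illustration, and the test uses a `SimpleNamespace` stand-in instead of a real optimizer.

```python
from typing import Any, Iterator


def main_params(optimizer: Any) -> Iterator[Any]:
    """Yield every parameter from the optimizer's param groups, in group order."""
    for group in optimizer.param_groups:
        # Each group is a dict; its "params" entry lists the parameters it owns.
        yield from group["params"]
```

Returning a generator avoids building an intermediate list when the caller only iterates once, for example to clip gradient norms.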
- optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
Hook to run the optimizer step.
- Return type
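The closure-based step can be sketched as follows: run the closure first (it computes the loss and the gradients), then apply the update only if the gradients are finite. This is a toy under stated assumptions, not Lightning's `optimizer_step`: `ToySGD` and the standalone `optimizer_step` function are hypothetical, and parameters and gradients are plain floats.

```python
import math
from typing import Callable, List, Optional


class ToySGD:
    """Hypothetical SGD optimizer over a list of float parameters."""

    def __init__(self, params: List[float], lr: float = 0.1):
        self.params = list(params)
        self.grads = [0.0] * len(self.params)
        self.lr = lr
        self.steps_taken = 0

    def step(self) -> None:
        self.params = [p - self.lr * g for p, g in zip(self.params, self.grads)]
        self.steps_taken += 1


def optimizer_step(optimizer: ToySGD, closure: Callable[[], Optional[float]]) -> Optional[float]:
    # The closure runs the forward and backward passes and fills optimizer.grads.
    loss = closure()
    if loss is None:
        # No loss was produced (e.g. backward was skipped): nothing to apply.
        return None
    if not all(math.isfinite(g) for g in optimizer.grads):
        # Overflowed gradients: skip the update, as loss-scaling schemes do.
        return loss
    optimizer.step()
    return loss
```

Running the closure inside the step lets the plugin inspect the gradients before deciding whether the update should happen at all.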