ApexMixedPrecisionPlugin
- class pytorch_lightning.plugins.precision.ApexMixedPrecisionPlugin(amp_level='O2')
 Bases: pytorch_lightning.plugins.precision.precision_plugin.PrecisionPlugin
 Mixed precision plugin based on NVIDIA Apex (https://github.com/NVIDIA/apex).
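In Lightning, amp_level is handed to Apex as its opt_level. As a rough guide to which values make sense, the sketch below lists Apex's documented opt levels O0 to O3 and normalises user input. APEX_OPT_LEVELS and validate_amp_level are hypothetical names used only for illustration; they are not part of Lightning or Apex.

```python
# Hypothetical helper, not part of PyTorch Lightning or Apex: it only
# documents the opt levels Apex accepts and normalises user input.
APEX_OPT_LEVELS = {
    "O0": "pure FP32 training (baseline)",
    "O1": "mixed precision via patched Torch functions and Tensor methods",
    "O2": "FP16 model weights with FP32 master weights",
    "O3": "pure FP16 training",
}


def validate_amp_level(amp_level: str = "O2") -> str:
    """Return a normalised Apex opt level or raise ValueError."""
    # Tolerate surrounding whitespace and a lowercase letter, e.g. " o1 ".
    level = amp_level.strip().upper()
    if level not in APEX_OPT_LEVELS:
        raise ValueError(
            f"amp_level must be one of {sorted(APEX_OPT_LEVELS)}, got {amp_level!r}"
        )
    return level
```

The default of "O2" matches the constructor signature above; any other string outside O0 to O3 is rejected early instead of failing later inside Apex.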
- backward(tensor, model, optimizer, *args, **kwargs)
 Run before the precision plugin executes backward.
- Parameters:
  model (LightningModule) – the model to be optimized
  optimizer (Optional[Optimizable]) – the current optimizer being used; None if using manual optimization
  *args (Any) – Positional arguments intended for the actual function that performs the backward, like backward().
  **kwargs (Any) – Keyword arguments for the same purpose as *args.
- Return type:
  None
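With Apex, the backward pass is typically wrapped in loss scaling: the loss is multiplied by a scale factor before backpropagation so that small FP16 gradients do not underflow, and the gradients are divided by the same factor afterwards. The pure-Python sketch below illustrates dynamic loss scaling in general. DynamicLossScaler is a hypothetical class, and its defaults (initial scale 2**16, factor 2, a window of 2000 clean steps) are assumptions modelled on common practice, not Apex's actual API.

```python
import math
from typing import List, Tuple


class DynamicLossScaler:
    """Hypothetical dynamic loss scaler (illustration only, not Apex's API)."""

    def __init__(self, init_scale: float = 2.0 ** 16, scale_factor: float = 2.0,
                 scale_window: int = 2000) -> None:
        self.loss_scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self._clean_steps = 0  # consecutive steps without inf/NaN gradients

    def scale(self, loss: float) -> float:
        # Multiply the loss so that small gradients stay representable in FP16.
        return loss * self.loss_scale

    def unscale(self, grads: List[float]) -> Tuple[List[float], bool]:
        # Divide gradients back by the scale and report whether any overflowed.
        overflow = any(not math.isfinite(g) for g in grads)
        return [g / self.loss_scale for g in grads], overflow

    def update(self, overflow: bool) -> None:
        if overflow:
            # Overflow: shrink the scale; the caller should skip optimizer.step().
            self.loss_scale /= self.scale_factor
            self._clean_steps = 0
        else:
            self._clean_steps += 1
            if self._clean_steps >= self.scale_window:
                # A long run of finite gradients: try a larger scale again.
                self.loss_scale *= self.scale_factor
                self._clean_steps = 0
```

In a training loop, the caller would backpropagate scaler.scale(loss), pass the resulting gradients through unscale, skip the optimizer step when overflow is True, and call update on every step. Real implementations perform the same steps on tensors in place.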
- load_state_dict(state_dict)
 Called when loading a checkpoint. Implement this to reload the precision plugin state from the given precision plugin state_dict.
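Checkpoint loading only works if load_state_dict restores exactly what the plugin saved. The minimal sketch below assumes the plugin's state is a plain dictionary holding a loss scale and a step counter. ToyPrecisionPlugin is hypothetical; the real Apex plugin is expected to delegate to Apex's own amp state rather than manage fields by hand.

```python
from typing import Any, Dict


class ToyPrecisionPlugin:
    """Hypothetical plugin whose only state is a loss scale and a step count."""

    def __init__(self) -> None:
        self.loss_scale = 2.0 ** 16
        self.clean_steps = 0

    def state_dict(self) -> Dict[str, Any]:
        # Build a fresh dict so later training does not mutate a saved checkpoint.
        return {"loss_scale": self.loss_scale, "clean_steps": self.clean_steps}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when loading a checkpoint: restore every saved field.
        self.loss_scale = float(state_dict["loss_scale"])
        self.clean_steps = int(state_dict["clean_steps"])
```

Keeping state_dict and load_state_dict symmetric means a checkpoint round trip leaves a freshly constructed plugin in the same state as the one that was saved.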
- main_params(optimizer)
 The main params of the model.
Returns the plain model params here; this may differ in other precision plugins.
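For a plugin that keeps no separate master weights, "the plain model params" are simply the parameters registered in the optimizer's param groups. The sketch below assumes a PyTorch-style optimizer exposing param_groups as a list of dicts with a "params" key, as torch.optim.Optimizer does; it is duck-typed so it runs without PyTorch. plain_main_params is a hypothetical name. Under Apex O2, by contrast, the optimizer steps FP32 master copies, so the main params can differ from the FP16 model weights.

```python
from types import SimpleNamespace
from typing import Any, Iterator


def plain_main_params(optimizer: Any) -> Iterator[Any]:
    """Yield every parameter the optimizer updates, group by group.

    Hypothetical helper: mirrors the "plain model params" behaviour for any
    object with a torch-style ``param_groups`` attribute.
    """
    for group in optimizer.param_groups:
        yield from group["params"]


if __name__ == "__main__":
    # Stand-in optimizer with two param groups (e.g. different learning rates).
    opt = SimpleNamespace(param_groups=[
        {"params": ["w1", "b1"], "lr": 1e-3},
        {"params": ["w2"], "lr": 1e-4},
    ])
    print(list(plain_main_params(opt)))  # ['w1', 'b1', 'w2']
```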