StochasticWeightAveraging
- class pytorch_lightning.callbacks.StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=None, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))
- Bases: pytorch_lightning.callbacks.base.Callback

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

This documentation is heavily inspired by PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.

For an explanation of SWA, please take a look here.

Warning: StochasticWeightAveraging is in beta and subject to change.

Warning: StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

Warning: StochasticWeightAveraging is currently only supported on every epoch.

See also how to enable it directly on the Trainer.

Parameters
- swa_epoch_start (Union[int, float]) – If provided as an int, the procedure will start from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure will start from epoch int(swa_epoch_start * max_epochs).
- swa_lrs (Union[float, List[float], None]) – The SWA learning rate to use:
  - None. Use the current learning rate of the optimizer at the time the SWA procedure starts.
  - float. Use this value for all parameter groups of the optimizer.
  - List[float]. A list of values, one for each parameter group of the optimizer.
 
- annealing_epochs (int) – Number of epochs in the annealing phase (default: 10).
- annealing_strategy (str) – Specifies the annealing strategy (default: "cos"):
  - "cos". For cosine annealing.
  - "linear". For linear annealing.
 
- avg_fn (Optional[Callable[[Tensor, Tensor, LongTensor], FloatTensor]]) – The averaging function used to update the parameters. The function must take the current value of the AveragedModel parameter, the current value of the model parameter and the number of models already averaged. If None, an equally weighted average is used (default: None).
- device (Union[device, str, None]) – If provided, the averaged model will be stored on device. If None, the device is inferred from pl_module (default: "cpu").
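How swa_epoch_start and swa_lrs resolve to concrete values can be sketched in plain Python. `resolve_swa_start` and `resolve_swa_lrs` are hypothetical helpers written for illustration, not part of the Lightning API, and the range and length checks are assumptions; the real callback may validate its arguments differently.

```python
from typing import List, Union


def resolve_swa_start(swa_epoch_start: Union[int, float], max_epochs: int) -> int:
    """Return the epoch at which SWA begins (hypothetical helper)."""
    if isinstance(swa_epoch_start, float):
        # Assumed check: a float is read as a fraction of max_epochs.
        if not 0.0 <= swa_epoch_start <= 1.0:
            raise ValueError("a float swa_epoch_start must lie in [0, 1]")
        return int(swa_epoch_start * max_epochs)
    # An int is used directly as the starting epoch.
    return swa_epoch_start


def resolve_swa_lrs(
    swa_lrs: Union[float, List[float], None], current_lrs: List[float]
) -> List[float]:
    """Return one SWA learning rate per optimizer parameter group (hypothetical helper)."""
    if swa_lrs is None:
        # Reuse the learning rates the optimizer has when SWA starts.
        return list(current_lrs)
    if isinstance(swa_lrs, (int, float)):
        # A single value is applied to every parameter group.
        return [float(swa_lrs)] * len(current_lrs)
    # Assumed check: a list must provide exactly one value per parameter group.
    if len(swa_lrs) != len(current_lrs):
        raise ValueError("swa_lrs must have one entry per parameter group")
    return list(swa_lrs)
```

With the default swa_epoch_start=0.8 and max_epochs=20, averaging would start at epoch 16; passing the int 5 starts it at epoch 5 regardless of max_epochs.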
 
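The annealing phase moves each learning rate from its value when SWA starts towards the SWA learning rate over annealing_epochs epochs. The sketch below shows the shape of the "cos" and "linear" strategies using a simplified closed-form interpolation; `annealed_lr` is a hypothetical helper, and PyTorch's `SWALR` scheduler applies these annealing factors through a step-by-step update that this sketch does not reproduce exactly.

```python
import math


def annealed_lr(
    epoch_in_phase: int,
    base_lr: float,
    swa_lr: float,
    annealing_epochs: int = 10,
    annealing_strategy: str = "cos",
) -> float:
    """Learning rate after `epoch_in_phase` annealing epochs (hypothetical helper)."""
    # Fraction of the annealing phase completed, clamped to [0, 1].
    t = min(1.0, max(0.0, epoch_in_phase / max(1, annealing_epochs)))
    if annealing_strategy == "cos":
        alpha = (1.0 - math.cos(math.pi * t)) / 2.0  # slow start, slow finish
    elif annealing_strategy == "linear":
        alpha = t  # constant rate of change
    else:
        raise ValueError(f"unknown annealing_strategy: {annealing_strategy!r}")
    # Interpolate from the starting learning rate towards the SWA learning rate.
    return swa_lr * alpha + base_lr * (1.0 - alpha)
```

Both strategies start at base_lr, reach swa_lr after annealing_epochs epochs and stay there; cosine annealing changes the rate more gently near the two ends of the phase than linear annealing does.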
- static avg_fn(averaged_model_parameter, model_parameter, num_averaged)
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.
- Return type: FloatTensor
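The default averaging rule is an equally weighted running mean: each new model parameter moves the average by (model - average) / (num_averaged + 1). The sketch uses plain floats in place of tensors so it runs without PyTorch. `default_avg_fn`, `ema_avg_fn` and `running_average` are hypothetical names; the exponential-moving-average variant only illustrates a custom function with the same three-argument signature that could be supplied as avg_fn.

```python
from typing import Callable, List

AvgFn = Callable[[float, float, int], float]


def default_avg_fn(averaged: float, current: float, num_averaged: int) -> float:
    # Equal-weight running mean: after k updates this equals the plain mean.
    return averaged + (current - averaged) / (num_averaged + 1)


def ema_avg_fn(averaged: float, current: float, num_averaged: int, decay: float = 0.9) -> float:
    # Hypothetical custom rule: exponential moving average, ignores num_averaged.
    return decay * averaged + (1.0 - decay) * current


def running_average(values: List[float], avg_fn: AvgFn) -> float:
    averaged = values[0]  # the first snapshot is taken as-is
    for num_averaged, value in enumerate(values[1:], start=1):
        averaged = avg_fn(averaged, value, num_averaged)
    return averaged
```

Averaging the snapshots 1.0, 2.0, 3.0 and 6.0 with `default_avg_fn` yields their arithmetic mean, 3.0, without ever storing the individual snapshots.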
 
- on_before_accelerator_backend_setup(trainer, pl_module)
- Called before the accelerator is set up.
- on_train_epoch_end(trainer, *args)
- Called when the train epoch ends. To access all batch outputs at the end of the epoch, either:
  - implement training_epoch_end in the LightningModule and access the outputs via the module, or
  - cache data across train batch hooks inside the callback implementation and post-process it in this hook.
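The second option, caching per-batch data inside the callback, follows the pattern in this framework-free sketch. `LossCollector` and its simplified hook signatures are hypothetical: in Lightning the hooks also receive the trainer, the LightningModule and batch-specific arguments, and the class would subclass Callback.

```python
from typing import List


class LossCollector:
    """Mimics a callback that caches batch outputs and post-processes them per epoch."""

    def __init__(self) -> None:
        self._batch_losses: List[float] = []
        self.epoch_means: List[float] = []

    def on_train_batch_end(self, loss: float) -> None:
        # Cache whatever the epoch-level hook will need later.
        self._batch_losses.append(loss)

    def on_train_epoch_end(self) -> None:
        # Post-process the cached data, then clear it for the next epoch.
        if self._batch_losses:
            self.epoch_means.append(sum(self._batch_losses) / len(self._batch_losses))
        self._batch_losses.clear()


# Drive the hooks by hand for two toy epochs of two batches each.
collector = LossCollector()
for epoch_losses in ([1.0, 3.0], [0.5, 1.5]):
    for loss in epoch_losses:
        collector.on_train_batch_end(loss)
    collector.on_train_epoch_end()
```

Clearing the cache at the end of each epoch keeps memory bounded and prevents one epoch's data from leaking into the next.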
 
- reset_batch_norm_and_save_state(pl_module)
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.
- reset_momenta()
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.
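The usual motivation for these two methods is that BatchNorm running statistics collected during training do not match the averaged weights, so PyTorch's `update_bn` resets them, sets each layer's momentum to None (which makes PyTorch BatchNorm use a cumulative moving average) while data is forwarded through the model, and then restores the saved momenta. The sketch below mimics that sequence with a hypothetical `ToyBatchNorm` that tracks only a running mean; it is an illustration of the idea, not the callback's implementation.

```python
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass(eq=False)  # identity-based hashing so layers can be dict keys
class ToyBatchNorm:
    running_mean: float = 0.0
    momentum: Optional[float] = 0.1
    num_batches_tracked: int = 0

    def reset_running_stats(self) -> None:
        self.running_mean = 0.0
        self.num_batches_tracked = 0

    def observe(self, batch_mean: float) -> None:
        self.num_batches_tracked += 1
        # momentum=None -> cumulative average with factor 1/n, as in torch BatchNorm.
        factor = 1.0 / self.num_batches_tracked if self.momentum is None else self.momentum
        self.running_mean = (1.0 - factor) * self.running_mean + factor * batch_mean


def reset_batch_norm_and_save_state(layers: List[ToyBatchNorm]) -> Dict[ToyBatchNorm, Optional[float]]:
    momenta: Dict[ToyBatchNorm, Optional[float]] = {}
    for layer in layers:
        layer.reset_running_stats()
        momenta[layer] = layer.momentum  # remember the original momentum
        layer.momentum = None  # switch to a cumulative average
    return momenta


def reset_momenta(momenta: Dict[ToyBatchNorm, Optional[float]]) -> None:
    for layer, momentum in momenta.items():
        layer.momentum = momentum


bn = ToyBatchNorm(running_mean=7.0)  # stale statistics from the training run
momenta = reset_batch_norm_and_save_state([bn])
for batch_mean in (2.0, 4.0, 6.0):  # one pass over (toy) training batches
    bn.observe(batch_mean)
reset_momenta(momenta)
```

After the pass the running mean equals the plain average of the observed batch means, and the layer's original momentum is back in place for any further training.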
- static update_parameters(average_model, model, n_averaged, avg_fn)
- Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.
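In PyTorch's swa_utils, the update walks the averaged model's parameters alongside the live model's: when no model has been averaged yet it copies the live values, otherwise it applies avg_fn, and it then increments the count of averaged models. The plain-Python sketch below mirrors that control flow with lists of floats in place of tensors; the function names are hypothetical stand-ins, and the sketch returns the new count instead of updating a counter in place.

```python
from typing import Callable, List

AvgFn = Callable[[float, float, int], float]


def equal_weight_avg(averaged: float, current: float, num_averaged: int) -> float:
    return averaged + (current - averaged) / (num_averaged + 1)


def update_parameters(
    average_params: List[float], model_params: List[float], n_averaged: int, avg_fn: AvgFn
) -> int:
    """Update average_params in place and return the new number of averaged models."""
    for i, (p_avg, p_model) in enumerate(zip(average_params, model_params)):
        if n_averaged == 0:
            average_params[i] = p_model  # first snapshot: copy the weights
        else:
            average_params[i] = avg_fn(p_avg, p_model, n_averaged)
    return n_averaged + 1


# Average three snapshots of a two-parameter "model", one per SWA epoch.
avg = [0.0, 0.0]
n = 0
for snapshot in ([1.0, 10.0], [3.0, 20.0], [5.0, 30.0]):
    n = update_parameters(avg, snapshot, n, equal_weight_avg)
```

Special-casing the first update means the averaged model's initial values never contribute to the result, so it does not matter how the averaged copy was initialized.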