StochasticWeightAveraging

class pytorch_lightning.callbacks.StochasticWeightAveraging(swa_epoch_start=0.8, swa_lrs=None, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))

Bases: pytorch_lightning.callbacks.base.Callback

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).

This documentation is highly inspired by PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.

For an explanation of SWA, please take a look here.

Warning

StochasticWeightAveraging is in beta and subject to change.

Warning

StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.

Warning

StochasticWeightAveraging is currently only supported on every epoch.

SWA can be activated directly from the Trainer as follows:

Trainer(stochastic_weight_avg=True)

Parameters
  • swa_epoch_start (Union[int, float]) – If provided as an int, the procedure starts from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure starts from epoch int(swa_epoch_start * max_epochs).

  • swa_lrs (Union[float, list, None]) – the SWA learning rate, given either as a single float applied to all param groups or as a list with one value per group.

  • annealing_epochs (int) – number of epochs in the annealing phase (default: 10)

  • annealing_strategy (str) – Specifies the annealing strategy (default: "cos"):

    • "cos": cosine annealing.

    • "linear": linear annealing.

  • avg_fn (Optional[Callable[[Tensor, Tensor, LongTensor], FloatTensor]]) – the averaging function used to update the parameters; the function must take in the current value of the AveragedModel parameter, the current value of model parameter and the number of models already averaged; if None, equally weighted average is used (default: None)

  • device (Union[device, str, None]) – if provided, the averaged model is stored on that device. When None is provided, the device is inferred from pl_module. (default: "cpu")
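As a rough illustration of how swa_epoch_start and the annealing arguments could be interpreted, here is a minimal pure-Python sketch. The helper names resolve_swa_start and annealed_lr are hypothetical and not part of the callback. The interpolation between the initial learning rate and the SWA learning rate is an assumption modeled on the annealing scheme in PyTorch's swa_utils, not a copy of the Lightning implementation.

```python
import math


def resolve_swa_start(swa_epoch_start, max_epochs):
    """Hypothetical helper: turn swa_epoch_start into a concrete epoch index."""
    if isinstance(swa_epoch_start, float):
        if not 0 <= swa_epoch_start <= 1:
            raise ValueError("a float swa_epoch_start must lie between 0 and 1")
        # A float is read as a fraction of the total number of epochs.
        return int(swa_epoch_start * max_epochs)
    # An int is taken as the epoch itself.
    return swa_epoch_start


def annealed_lr(initial_lr, swa_lr, epochs_into_annealing,
                annealing_epochs=10, annealing_strategy="cos"):
    """Hypothetical helper: move from initial_lr to swa_lr over annealing_epochs."""
    # Fraction of the annealing phase completed, clamped to [0, 1].
    t = min(1.0, max(0.0, epochs_into_annealing / annealing_epochs))
    if annealing_strategy == "cos":
        alpha = (1 - math.cos(math.pi * t)) / 2
    elif annealing_strategy == "linear":
        alpha = t
    else:
        raise ValueError("annealing_strategy must be 'cos' or 'linear'")
    # Assumed interpolation: alpha = 0 gives initial_lr, alpha = 1 gives swa_lr.
    return (1 - alpha) * initial_lr + alpha * swa_lr


start_epoch = resolve_swa_start(0.5, max_epochs=20)
halfway_lr = annealed_lr(0.1, 0.05, epochs_into_annealing=5,
                         annealing_epochs=10, annealing_strategy="linear")
```

Under these assumptions, both strategies start at the initial learning rate and reach the SWA learning rate after annealing_epochs epochs; they differ only in the shape of the curve in between.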

static avg_fn(averaged_model_parameter, model_parameter, num_averaged)

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97

Return type

FloatTensor
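The equally weighted average used when avg_fn is None can be written as a running mean. The sketch below uses plain Python floats instead of tensors to keep it self-contained. The names equal_weight_avg and ema_avg are hypothetical, and ema_avg is only an illustration of a custom function with the same three-argument shape, not something the callback provides.

```python
def equal_weight_avg(averaged_value, new_value, num_averaged):
    """Running mean: fold one more value into an average of num_averaged values."""
    return averaged_value + (new_value - averaged_value) / (num_averaged + 1)


def ema_avg(averaged_value, new_value, num_averaged, decay=0.9):
    """Hypothetical alternative: exponential moving average (ignores num_averaged)."""
    return decay * averaged_value + (1 - decay) * new_value


# Fold a sequence of values into the running mean one at a time.
running = 0.0
for n, value in enumerate([1.0, 2.0, 3.0, 6.0]):
    running = equal_weight_avg(running, value, n)

print(running)  # 3.0, the arithmetic mean of the four values
```

A custom avg_fn passed to the callback would receive tensors rather than floats, but the arithmetic has the same form.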

on_before_accelerator_backend_setup(trainer, pl_module)

Called before the accelerator is set up.

on_fit_start(trainer, pl_module)

Called when fit begins.

on_train_end(trainer, pl_module)

Called when training ends.

on_train_epoch_end(trainer, *args)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module, or

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.

on_train_epoch_start(trainer, pl_module)

Called when the training epoch begins.

reset_batch_norm_and_save_state(pl_module)

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154

reset_momenta()

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165

static update_parameters(average_model, model, n_averaged, avg_fn)

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112
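To make the role of n_averaged and avg_fn concrete, here is a simplified sketch of a parameter update of this kind. It assumes parameters are plain lists of floats rather than tensors on a device, and it returns the new count instead of mutating a counter, so it should be read as an illustration of the idea and not as the callback's code. The name update_parameters_sketch and the sample values are hypothetical.

```python
def equal_weight_avg(averaged_value, new_value, num_averaged):
    """Running mean over num_averaged previously seen values."""
    return averaged_value + (new_value - averaged_value) / (num_averaged + 1)


def update_parameters_sketch(average_params, model_params, n_averaged, avg_fn):
    """Fold the current model parameters into the averaged ones, in place.

    Returns the updated number of averaged models.
    """
    for i, (p_avg, p_model) in enumerate(zip(average_params, model_params)):
        if n_averaged == 0:
            # First snapshot: the average is just a copy of the model.
            average_params[i] = p_model
        else:
            average_params[i] = avg_fn(p_avg, p_model, n_averaged)
    return n_averaged + 1


# Three hypothetical end-of-epoch snapshots of a two-parameter model.
snapshots = [[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]]
average = [0.0, 0.0]
n_averaged = 0
for params in snapshots:
    n_averaged = update_parameters_sketch(average, params, n_averaged, equal_weight_avg)

print(average, n_averaged)  # [3.0, 20.0] 3
```

With the equally weighted avg_fn, the averaged parameters end up as the element-wise mean of all snapshots seen so far.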