StochasticWeightAveraging
- class lightning.pytorch.callbacks.StochasticWeightAveraging(swa_lrs, swa_epoch_start=0.8, annealing_epochs=10, annealing_strategy='cos', avg_fn=None, device=device(type='cpu'))
Bases: Callback
Implements the Stochastic Weight Averaging (SWA) Callback to average a model.
Stochastic Weight Averaging was proposed in Averaging Weights Leads to Wider Optima and Better Generalization by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson (UAI 2018).
This documentation is heavily inspired by PyTorch’s work on SWA. The callback arguments follow the scheme defined in PyTorch’s swa_utils package.
For an explanation of SWA, see the paper above and the PyTorch material on SWA.
Warning
This is an experimental feature.
Warning
StochasticWeightAveraging is currently not supported for multiple optimizers/schedulers.
Warning
StochasticWeightAveraging is currently only supported on every epoch.
See also how to enable it directly on the Trainer.
- Parameters:
swa_lrs (Union[float, List[float]]) – The SWA learning rate to use:
float. Use this value for all parameter groups of the optimizer.
List[float]. A list of values, one for each parameter group of the optimizer.
swa_epoch_start (Union[int, float]) – If provided as an int, the procedure starts from the swa_epoch_start-th epoch. If provided as a float between 0 and 1, the procedure starts from epoch int(swa_epoch_start * max_epochs).
annealing_epochs (int) – Number of epochs in the annealing phase (default: 10).
annealing_strategy – Specifies the annealing strategy (default: "cos"):
"cos". For cosine annealing.
"linear". For linear annealing.
avg_fn (Optional[Callable[[Tensor, Tensor, Tensor], Tensor]]) – The averaging function used to update the parameters. It must take the current value of the AveragedModel parameter, the current value of the model parameter, and the number of models already averaged. If None, an equally weighted average is used (default: None).
device (Union[device, str, None]) – If provided, the averaged model is stored on this device. When None is provided, the device is inferred from pl_module (default: "cpu").
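The interplay of swa_epoch_start, annealing_epochs and annealing_strategy can be sketched in plain Python. The helper names resolve_swa_start and annealed_lr are hypothetical. The blending formula is an assumption modelled on the annealing used by torch.optim.swa_utils.SWALR, and is not taken from this callback's source.

```python
import math


def resolve_swa_start(swa_epoch_start, max_epochs):
    """Return the epoch at which SWA starts, following the rule described above."""
    if isinstance(swa_epoch_start, float):
        if not 0 <= swa_epoch_start <= 1:
            raise ValueError("a float swa_epoch_start must lie between 0 and 1")
        return int(swa_epoch_start * max_epochs)
    return swa_epoch_start


def annealed_lr(initial_lr, swa_lr, step, annealing_epochs, strategy="cos"):
    """Blend from initial_lr to swa_lr over annealing_epochs (assumed formula)."""
    # Fraction of the annealing phase completed, clamped to [0, 1].
    t = min(1.0, max(0.0, step / max(1, annealing_epochs)))
    if strategy == "cos":
        alpha = (1 - math.cos(math.pi * t)) / 2
    elif strategy == "linear":
        alpha = t
    else:
        raise ValueError(f"unknown annealing strategy: {strategy!r}")
    return swa_lr * alpha + initial_lr * (1 - alpha)


start_epoch = resolve_swa_start(0.8, max_epochs=10)  # int(0.8 * 10)
schedule = [annealed_lr(0.1, 0.01, step, 5, "linear") for step in range(7)]
```

In this sketch the learning rate moves from 0.1 to 0.01 over five steps and then stays at the SWA value; the cosine variant follows the same endpoints with a smoother transition.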
- static avg_fn(averaged_model_parameter, model_parameter, num_averaged)
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97.
- Return type: Tensor
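The equally weighted average can be illustrated with plain floats standing in for tensors. The function names equal_weight_avg and ema_avg are hypothetical. The first mirrors the running-mean update in the PyTorch code linked above; the second is an assumed example of a custom rule with the same three-argument signature (the decay value is illustrative).

```python
def equal_weight_avg(averaged_value, new_value, num_averaged):
    """Fold one new value into the mean of num_averaged earlier values."""
    return averaged_value + (new_value - averaged_value) / (num_averaged + 1)


def ema_avg(averaged_value, new_value, num_averaged, decay=0.9):
    """Hypothetical custom rule: exponential moving average (ignores the count)."""
    return decay * averaged_value + (1 - decay) * new_value


snapshots = [1.0, 2.0, 6.0]  # stand-ins for one weight at three epochs
average = snapshots[0]
for count, value in enumerate(snapshots[1:], start=1):
    average = equal_weight_avg(average, value, count)
print(average)  # 3.0, the arithmetic mean of the snapshots
```

A function shaped like ema_avg could be supplied through the avg_fn argument when a recency-weighted average is preferred over the equal-weight default.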
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement to reload the callback state given the callback’s state_dict.
- on_train_epoch_end(trainer, *args)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- Return type: None
- reset_batch_norm_and_save_state(pl_module)
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154.
- Return type: None
- reset_momenta()
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165.
- Return type: None
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- Return type: None
- static update_parameters(average_model, model, n_averaged, avg_fn)
Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112.
- Return type: None
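The parameter update can be sketched without tensors, assuming it follows the PyTorch code linked above: the first snapshot is copied into the average and later snapshots are folded in through avg_fn. In this simplified version, plain lists stand in for model parameters and the function returns the new count instead of incrementing a counter tensor in place; both choices are assumptions made for illustration.

```python
def update_parameters(average_params, model_params, n_averaged, avg_fn):
    """Fold model_params into average_params in place; return the new count."""
    for i, new_value in enumerate(model_params):
        if n_averaged == 0:
            # First snapshot: the average is simply a copy of the model.
            average_params[i] = new_value
        else:
            average_params[i] = avg_fn(average_params[i], new_value, n_averaged)
    return n_averaged + 1


def running_mean(avg, new, n):
    """Equal-weight update used here as the averaging function."""
    return avg + (new - avg) / (n + 1)


average = [0.0, 0.0]
n_averaged = 0
for snapshot in ([2.0, 4.0], [4.0, 8.0], [6.0, 0.0]):
    n_averaged = update_parameters(average, snapshot, n_averaged, running_mean)
print(average, n_averaged)  # [4.0, 4.0] 3
```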