StochasticWeightAveraging validation logging and checkpoints


I’m trying to understand how to use StochasticWeightAveraging properly. From a first look at the code, it seems that the averaged model is only copied back into the actual model once training finishes at max_epochs. I want to check the implications of this:

(a) Does this mean any metrics calculated during the validation steps will not use the average model, but the raw model being trained?
(b) And if I load a model through the usual checkpoint mechanics from any checkpoint saved before max_epochs, will it be the raw training model and not the averaged model?

Thanks for any help on this!

EDIT: in case it helps explain why I’m asking, I don’t usually set max_epochs; I stop training manually when I see overfitting or convergence. For the validation part of the question, it would be nice to monitor the performance of SWA in my dashboard (e.g. wandb).

Hi, I’d also like to know this.

I’m also not entirely sure why both swa_epoch_start and annealing_epochs are necessary. If swa_epoch_start=0.8 in a 10-epoch training run, doesn’t that mean annealing_epochs must equal 2?
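To make the arithmetic in this question concrete, here is a simplified pure-Python sketch. It assumes that a float swa_epoch_start is mapped to the start epoch int(swa_epoch_start * max_epochs), and that annealing_epochs only sets how many epochs the learning rate takes to move from its current value to the SWA learning rate, after which it is held there. The helpers swa_start_epoch and annealed_lr and the learning-rate values are hypothetical; the cosine ramp mirrors the shape of PyTorch's SWALR scheduler but not necessarily its exact per-step timing.

```python
import math


def swa_start_epoch(swa_epoch_start, max_epochs):
    # Assumption: a float is a fraction of max_epochs (truncated to an int),
    # an int is used directly as a 0-based epoch index.
    if isinstance(swa_epoch_start, float):
        return int(swa_epoch_start * max_epochs)
    return swa_epoch_start


def annealed_lr(initial_lr, swa_lr, epochs_into_swa, annealing_epochs):
    # Simplified cosine ramp from initial_lr to swa_lr over annealing_epochs,
    # then held constant at swa_lr (t is clamped to 1).
    t = min(1.0, epochs_into_swa / max(1, annealing_epochs))
    alpha = (1 - math.cos(math.pi * t)) / 2
    return swa_lr * alpha + initial_lr * (1 - alpha)


max_epochs = 10
start = swa_start_epoch(0.8, max_epochs)
# Epochs from the SWA start to the end of training (0-based).
swa_epochs = list(range(start, max_epochs))

for annealing_epochs in (1, 2, 5):
    lrs = [round(annealed_lr(0.1, 0.01, e - start, annealing_epochs), 4) for e in swa_epochs]
    print(f"annealing_epochs={annealing_epochs}: {lrs}")
```

Under these assumptions the two parameters describe different things: swa_epoch_start fixes when SWA begins, while annealing_epochs only controls the learning-rate ramp, so a short ramp reaches the SWA learning rate early and holds it, and a long one may not finish before training ends.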

Re (a): yes, metrics are always calculated on the “raw” model, as you called it! I sketched out how you could validate on either of the models here: Support checkpoint save and load with Stochastic Weight Averaging by adamreeve · Pull Request #9938 · Lightning-AI/lightning · GitHub
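To illustrate the distinction in (a), here is a minimal plain-PyTorch sketch; it is neither the Lightning callback nor the approach from the linked PR. It uses torch.optim.swa_utils.AveragedModel to keep a running average of the weights from a hypothetical swa_start epoch onward, then evaluates both the raw model (what the validation loop reports, per the answer above) and the averaged copy. The toy data, model, learning rate and swa_start value are made up for illustration.

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel

torch.manual_seed(0)

# Hypothetical toy regression problem and model.
x = torch.randn(64, 3)
y = x @ torch.tensor([[1.0], [-2.0], [0.5]])
model = nn.Linear(3, 1)
swa_model = AveragedModel(model)  # holds a running average of model's parameters
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

swa_start = 5  # hypothetical epoch at which averaging begins
for epoch in range(10):
    # One full-batch step per "epoch" keeps the sketch short.
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
    if epoch >= swa_start:
        swa_model.update_parameters(model)  # fold the current weights into the average


@torch.no_grad()
def val_loss(m):
    return loss_fn(m(x), y).item()


raw_loss = val_loss(model)      # the raw model being trained
swa_loss = val_loss(swa_model)  # the averaged weights
print(f"raw={raw_loss:.4f} swa={swa_loss:.4f}")
```

Logging both numbers (e.g. to wandb) is one way to watch SWA's effect during training. For models with batch norm, the averaged copy's running statistics would also need recomputing, for example with torch.optim.swa_utils.update_bn, before its validation numbers are meaningful.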

Re (b): checkpointing also saves the state of the SWA callback (including the averaged model's state dict) once SWA has started. You can get it by loading the checkpoint file with ckpt = torch.load(<path-to-ckpt>) and then reading the state dict from ckpt['callbacks']['StochasticWeightAveraging']['average_model_state'].
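A minimal sketch of that recipe as a helper, assuming (not verified here) that the stored average_model_state has the same keys as your LightningModule's own state_dict, so it can be loaded straight into the module. The helper name load_swa_weights is hypothetical. weights_only=False is passed because recent PyTorch versions default torch.load to weights_only=True, which can reject the non-tensor objects in a Lightning checkpoint; only do this for checkpoint files you trust.

```python
import torch
from torch import nn


def load_swa_weights(model: nn.Module, ckpt_path: str) -> nn.Module:
    """Hypothetical helper: copy the SWA-averaged weights from a checkpoint into model."""
    # weights_only=False: Lightning checkpoints hold more than plain tensors.
    # Only use this on checkpoint files you trust.
    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    # Key path from the answer above; assumed to match model.state_dict() keys.
    swa_state = ckpt["callbacks"]["StochasticWeightAveraging"]["average_model_state"]
    model.load_state_dict(swa_state)
    return model
```

For example, with a hypothetical MyLightningModule you could call MyLightningModule.load_from_checkpoint(path) as usual, which gives the raw weights, and then pass the result to load_swa_weights(model, path) to swap in the averaged ones.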