Hi,
I’m trying to understand how to use StochasticWeightAveraging properly. From a first look at the code, it looks to me like the averaged model is only copied back to the actual model after training finishes at max_epochs. I want to check the implications of this:
(a) Does this mean any metrics calculated during the validation steps will not use the average model, but the raw model being trained?
(b) And if I load a model using the usual checkpoint mechanics for any checkpoint before max_epochs, will this be the raw training model and not the average model?
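To make the question concrete, here is a toy sketch of what I think is happening. It is plain Python with a single float standing in for the model weights. It is not Lightning's actual implementation, and every name in it (run_toy_training, swa_start, etc.) is made up for illustration.

```python
# Toy sketch only: a single float stands in for the model weights.
# This is NOT Lightning's implementation; all names are hypothetical.

def run_toy_training(raw_per_epoch, swa_start, max_epochs):
    """Track raw weights and a separate running average from swa_start on."""
    raw = None
    avg = None
    n_averaged = 0
    seen_by_validation = []  # what validation/checkpointing would see each epoch

    for epoch in range(max_epochs):
        raw = raw_per_epoch[epoch]  # pretend this is the result of a training epoch

        if epoch >= swa_start:
            # Incremental mean of the raw weights, kept apart from the model itself
            n_averaged += 1
            avg = raw if n_averaged == 1 else avg + (raw - avg) / n_averaged

        # My understanding: validation and checkpoints use the raw weights here
        seen_by_validation.append(raw)

    # Only after the last epoch is the average copied back into the model
    final = avg if avg is not None else raw
    return seen_by_validation, avg, final


seen, avg, final = run_toy_training([4.0, 2.0, 1.0, 3.0], swa_start=1, max_epochs=4)
print(seen)   # raw weights at each validation step
print(final)  # averaged weights, only available once training has ended
```

If this picture is right, nothing logged during training ever reflects the averaged weights, which is what (a) and (b) are asking about.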
Thanks for any help on this!
EDIT: just in case it’s useful to understand why I’m wondering this, I don’t tend to set a max_epochs and usually stop training manually when I see overfitting or convergence. For the validation part of the question, it would be nice to monitor the performance of SWA in my dashboard (e.g. wandb).