As per the docs, SWA can be enabled either as a callback:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging
swa_callback = StochasticWeightAveraging(swa_lrs=5e-4, swa_epoch_start=1)
trainer = pl.Trainer(callbacks=[swa_callback, ...])
or as a flag passed to the Trainer:
trainer = pl.Trainer(..., stochastic_weight_avg=True, ...)
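For context, the "averaged model" here is an equal-weight running mean of the weight snapshots taken from `swa_epoch_start` onwards. The sketch below shows only that update rule in plain Python. It assumes the default equal-weight averaging used by `torch.optim.swa_utils.AveragedModel`. The weights are plain lists of floats rather than tensors, and `swa_update` and the snapshot values are hypothetical.

```python
from typing import List


def swa_update(avg: List[float], current: List[float], n_averaged: int) -> List[float]:
    """Fold one new weight snapshot into the running SWA average.

    Equal-weight running mean: avg + (current - avg) / (n_averaged + 1),
    which equals the plain mean of every snapshot folded in so far.
    """
    return [a + (c - a) / (n_averaged + 1) for a, c in zip(avg, current)]


# Hypothetical per-epoch weight snapshots of a 2-parameter model,
# collected from swa_epoch_start onwards.
snapshots = [[1.0, 4.0], [2.0, 2.0], [3.0, 0.0]]

avg = list(snapshots[0])  # the first snapshot initialises the average
for n, weights in enumerate(snapshots[1:], start=1):
    avg = swa_update(avg, weights, n)

print(avg)  # → [2.0, 2.0]
```

Because the running mean only needs the current average and a counter, SWA can keep a single extra copy of the weights instead of storing every snapshot.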
At the end of training, I expected two models to be saved: the model trained by the Trainer and the averaged model produced by SWA. However, I see only one model.
How can I get the averaged model generated by SWA?
I was also wondering about this, given the lack of documentation.
The way it works is as follows:
Once training finishes, the averaged weights from SWA are copied back into the original LightningModule. So if you save a checkpoint after training has finished, via trainer.save_checkpoint(filepath), the saved weights are the SWA-averaged weights.
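A toy simulation of that flow may make the "only one model" result clearer. It is a sketch under stated assumptions. `ToyModule`, `ToySWA`, their hook methods and `save_checkpoint` here are hypothetical stand-ins and do not mirror Lightning's real classes or hook signatures. The only assumption carried over from the explanation above is the order of events: the callback keeps its own averaged copy during training, copies it into the module when training ends, and the checkpoint then serialises whatever the module holds. The `"state_dict"` key mimics where a Lightning checkpoint stores model weights.

```python
import copy
import json


class ToyModule:
    """Hypothetical stand-in for a LightningModule: just a dict of weights."""

    def __init__(self, weights):
        self.weights = dict(weights)


class ToySWA:
    """Hypothetical stand-in for the SWA callback: keeps a separate averaged copy."""

    def __init__(self, module):
        self.average = copy.deepcopy(module.weights)
        self.n_averaged = 1

    def on_epoch_end(self, module):
        # Fold the module's current weights into the running average.
        self.n_averaged += 1
        for name, value in module.weights.items():
            self.average[name] += (value - self.average[name]) / self.n_averaged

    def on_train_end(self, module):
        # Overwrite the trained weights with the averaged ones.
        module.weights = dict(self.average)


def save_checkpoint(module):
    # Whatever the module holds at save time is what ends up in the checkpoint.
    return json.dumps({"state_dict": module.weights})


module = ToyModule({"w": 1.0})
swa = ToySWA(module)
for new_w in (2.0, 3.0):
    module.weights["w"] = new_w  # pretend an optimiser step changed the weight
    swa.on_epoch_end(module)
swa.on_train_end(module)

ckpt = json.loads(save_checkpoint(module))
print(ckpt["state_dict"]["w"])  # → 2.0
```

The last trained value (3.0) is gone after `on_train_end`. Only the averaged value is left to be saved, which matches seeing a single model at the end of training.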
Once SWA has started, checkpoints also include the SWA state (including the averaged model's state dict). You can access it by loading the checkpoint file with
ckpt = torch.load(<path-to-ckpt>) and then find the state dict within