I’m trying to implement SWA from this guide:
I’ve broken the example up as follows:
This bit goes in the __init__ for the LightningModule:
self.swa_model = AveragedModel(self.net)
self.swa_start = 5
This goes in configure_optimizers:
self.swa_scheduler = SWALR(optimizer, swa_lr=0.05)
And this bit goes in train_epoch_end:
if self.trainer.current_epoch > self.swa_start:
    torch.optim.swa_utils.update_bn(self.train_dataloader(), self.swa_model)
I’m getting this error when torch.optim.swa_utils.update_bn is called:
RuntimeError: Expected tensor to have CPU Backend, but got tensor with CUDA Backend (while checking arguments for batch_norm_cpu)
I’m guessing I need to define self.swa_model in such a way that it gets put onto the correct device.
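For reference, here is a minimal sketch of what the device handling might look like in plain PyTorch, outside Lightning. The error mentions batch_norm_cpu, which suggests (this is an assumption, not something confirmed) that the batches coming from the dataloader are still on the CPU while the averaged model's weights are on the GPU. Both AveragedModel and update_bn accept an optional device argument. The small network and random dataset below are hypothetical stand-ins for self.net and self.train_dataloader().

```python
import torch
from torch import nn
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical stand-in for self.net: any model containing BatchNorm layers.
net = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

# Falls back to CPU when no GPU is available, so the sketch runs anywhere.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device)

# AveragedModel deep-copies the network; `device` controls where the copy lives.
swa_model = AveragedModel(net, device=device)

# Toy loader standing in for self.train_dataloader(); it yields CPU tensors.
loader = DataLoader(
    TensorDataset(torch.randn(32, 4), torch.randn(32, 2)), batch_size=8
)

# Fold the current weights into the running average (first call just copies them).
swa_model.update_parameters(net)

# Passing `device` makes update_bn move each batch before the forward pass,
# so the inputs and the BatchNorm weights end up on the same device.
update_bn(loader, swa_model, device=device)
```

Inside the LightningModule, the equivalent call would presumably be update_bn(self.train_dataloader(), self.swa_model, device=self.device), but that is untested here.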
Is there an example somewhere of using SWA with PL? Thanks!