Source code for pytorch_lightning.callbacks.stochastic_weight_avg
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Any, Callable, cast, Dict, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from lightning_fabric.utilities.types import LRScheduler
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.strategies.fully_sharded_native import DDPFullyShardedNativeStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]

class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[_AVG_FN] = None,
        device: Optional[Union[torch.device, str]] = torch.device("cpu"),
    ):
        r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

        Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to Wider Optima and
        Better Generalization`` by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov
        and Andrew Gordon Wilson (UAI 2018).

        This documentation is highly inspired by PyTorch's work on SWA. The callback arguments
        follow the scheme defined in PyTorch's ``swa_utils`` package.

        For an explanation of SWA, please take a look
        `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.

        .. warning:: ``StochasticWeightAveraging`` is in beta and subject to change.

        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple optimizers/schedulers.

        .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.

        See also how to :ref:`enable it directly on the Trainer <advanced/training_tricks:Stochastic Weight Averaging>`

        Arguments:

            swa_lrs: The SWA learning rate to use:

                - ``float``. Use this value for all parameter groups of the optimizer.
                - ``List[float]``. A list of values for each parameter group of the optimizer.

            swa_epoch_start: If provided as int, the procedure will start from
                the ``swa_epoch_start``-th epoch. If provided as float between 0 and 1,
                the procedure will start from ``int(swa_epoch_start * max_epochs)`` epoch

            annealing_epochs: number of epochs in the annealing phase (default: 10)

            annealing_strategy: Specifies the annealing strategy (default: "cos"):

                - ``"cos"``. For cosine annealing.
                - ``"linear"`` For linear annealing

            avg_fn: the averaging function used to update the parameters;
                the function must take in the current value of the
                :class:`AveragedModel` parameter, the current value of :attr:`model`
                parameter and the number of models already averaged; if None,
                an equally weighted average is used (default: ``None``)

            device: if provided, the averaged model will be stored on the ``device``.
                When None is provided, it will infer the `device` from ``pl_module``.
                (default: ``"cpu"``)

        """
        err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
            raise MisconfigurationException(err_msg)
        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
            raise MisconfigurationException(err_msg)

        wrong_type = not isinstance(swa_lrs, (float, list))
        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
        wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
        if wrong_type or wrong_float or wrong_list:
            raise MisconfigurationException("The `swa_lrs` should be a positive float, or a list of positive floats")

        if avg_fn is not None and not callable(avg_fn):
            raise MisconfigurationException("The `avg_fn` should be callable.")

        if device is not None and not isinstance(device, (torch.device, str)):
            raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

        self.n_averaged: Optional[Tensor] = None
        self._swa_epoch_start = swa_epoch_start
        self._swa_lrs = swa_lrs
        self._annealing_epochs = annealing_epochs
        self._annealing_strategy = annealing_strategy
        self._avg_fn = avg_fn or self.avg_fn
        self._device = device
        self._model_contains_batch_norm: Optional[bool] = None
        self._average_model: Optional["pl.LightningModule"] = None
        self._initialized = False
        self._swa_scheduler: Optional[LRScheduler] = None
        self._scheduler_state: Optional[Dict] = None
        self._init_n_averaged = 0
        self._latest_update_epoch = -1
        self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {}
        self._max_epochs: int

    @property
    def swa_start(self) -> int:
        assert isinstance(self._swa_epoch_start, int)
        return max(self._swa_epoch_start - 1, 0)  # 0-based

    @property
    def swa_end(self) -> int:
        return self._max_epochs - 1  # 0-based

    @staticmethod
    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DDPFullyShardedNativeStrategy, DeepSpeedStrategy)):
            raise MisconfigurationException("SWA does not currently support sharded models.")

        # copy the model before moving it to the accelerator device.
        with pl_module._prevent_trainer_and_dataloaders_deepcopy():
            self._average_model = deepcopy(pl_module)

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if len(trainer.optimizers) != 1:
            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

        if len(trainer.lr_scheduler_configs) > 1:
            raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

        assert trainer.max_epochs is not None
        if isinstance(self._swa_epoch_start, float):
            self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

        self._max_epochs = trainer.max_epochs
        if self._model_contains_batch_norm:
            # virtually increase max_epochs to perform batch norm update on latest epoch.
            assert trainer.fit_loop.max_epochs is not None
            trainer.fit_loop.max_epochs += 1

        if self._scheduler_state is not None:
            self._clear_schedulers(trainer)
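
    # Worked example (illustrative sketch, values assumed): with ``Trainer(max_epochs=20)`` and
    # the default ``swa_epoch_start=0.8``, the conversion above gives ``int(20 * 0.8) == 16``,
    # so the ``swa_start`` property returns 15 (0-based), ``swa_end`` returns 19, and the SWA
    # averaging phase covers the final five training epochs.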

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
            self._initialized = True

            # move the average model to the requested device.
            assert self._average_model is not None
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            assert trainer.max_epochs is not None
            self._swa_scheduler = cast(
                LRScheduler,
                SWALR(
                    optimizer,
                    swa_lr=self._swa_lrs,  # type: ignore[arg-type]
                    anneal_epochs=self._annealing_epochs,
                    anneal_strategy=self._annealing_strategy,
                    last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
                ),
            )
            if self._scheduler_state is not None:
                # Restore scheduler state from checkpoint
                self._swa_scheduler.load_state_dict(self._scheduler_state)
            elif trainer.current_epoch != self.swa_start:
                # Log a warning if we're initializing after start without any checkpoint data,
                # as behaviour will be different compared to having checkpoint data.
                rank_zero_warn(
                    "SWA is initializing after swa_start without any checkpoint data. "
                    "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
                )

            # We assert that there is only one optimizer on fit start, so we know opt_idx is always 0
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
            assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1

            if trainer.lr_scheduler_configs:
                scheduler_cfg = trainer.lr_scheduler_configs[0]
                if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
                    rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`"
                )
                trainer.lr_scheduler_configs[0] = default_scheduler_cfg
            else:
                trainer.lr_scheduler_configs.append(default_scheduler_cfg)

            if self.n_averaged is None:
                self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
            trainer.current_epoch > self._latest_update_epoch
        ):
            assert self.n_averaged is not None
            assert self._average_model is not None
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
            self._latest_update_epoch = trainer.current_epoch

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:
            # Transfer weights from average model to pl_module
            assert self._average_model is not None
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics
            # Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
            trainer.num_training_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            trainer.accumulate_grad_batches = trainer.num_training_batches

    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # the trainer increases the current epoch before this hook is called
        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
            # BatchNorm epoch update. Reset state
            trainer.accumulate_grad_batches = self._accumulate_grad_batches
            trainer.num_training_batches -= 1
            assert trainer.fit_loop.max_epochs is not None
            trainer.fit_loop.max_epochs -= 1
            self.reset_momenta()
        elif trainer.current_epoch - 1 == self.swa_end:
            # Last SWA epoch. Transfer weights from average model to pl_module
            assert self._average_model is not None
            self.transfer_weights(self._average_model, pl_module)

    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
        self.momenta = {}
        for module in pl_module.modules():
            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                continue
            assert module.running_mean is not None
            module.running_mean = torch.zeros_like(
                module.running_mean,
                device=pl_module.device,
                dtype=module.running_mean.dtype,
            )
            assert module.running_var is not None
            module.running_var = torch.ones_like(
                module.running_var,
                device=pl_module.device,
                dtype=module.running_var.dtype,
            )
            self.momenta[module] = module.momentum
            module.momentum = None  # type: ignore[assignment]
            assert module.num_batches_tracked is not None
            module.num_batches_tracked *= 0

    def reset_momenta(self) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
        for bn_module in self.momenta:
            bn_module.momentum = self.momenta[bn_module]  # type: ignore[assignment]

    @staticmethod
    def update_parameters(
        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
    ) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
            p_swa_ = p_swa.detach()
            p_model_ = p_model.detach().to(device)
            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
            p_swa_.copy_(src)
        n_averaged += 1

    @staticmethod
    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
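
    # Illustrative note on the default ``avg_fn`` above: since the very first update in
    # ``update_parameters`` copies the model parameters directly (``n_averaged == 0``), the
    # recurrence
    #     avg_k = avg_{k-1} + (x_k - avg_{k-1}) / k
    # keeps ``avg_k`` equal to the arithmetic mean of the k parameter snapshots averaged so far.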

    @staticmethod
    def _clear_schedulers(trainer: "pl.Trainer") -> None:
        # If we have scheduler state saved, clear the scheduler configs so that we don't try to
        # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
        # We'll configure the scheduler and re-load its state in on_train_epoch_start.
        # Note that this relies on the callback state being restored before the scheduler state is
        # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
        # writing that is only True for deepspeed which is already not supported by SWA.
        # See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
        if trainer.lr_scheduler_configs:
            assert len(trainer.lr_scheduler_configs) == 1
            trainer.lr_scheduler_configs.clear()

    def _load_average_model_state(self, model_state: Any) -> None:
        if self._average_model is None:
            return
        self._average_model.load_state_dict(model_state)
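
A minimal usage sketch of the callback documented above (``MyLightningModule`` and the chosen values are assumptions, not part of this module): the callback is attached through ``Trainer(callbacks=...)``, and an optional custom ``avg_fn`` such as an exponential moving average can replace the default equal-weight running mean.

from torch import Tensor

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import StochasticWeightAveraging


def ema_avg(averaged: Tensor, current: Tensor, num_averaged: Tensor) -> Tensor:
    # Exponential moving average instead of the default equal-weight running mean.
    return 0.9 * averaged + 0.1 * current


# ``MyLightningModule`` is a placeholder for a user-defined LightningModule.
model = MyLightningModule()
trainer = Trainer(
    max_epochs=20,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.8, avg_fn=ema_avg)],
)
trainer.fit(model)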