Source code for lightning.pytorch.callbacks.stochastic_weight_avg

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Stochastic Weight Averaging Callback
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""

from copy import deepcopy
from typing import Any, Callable, Dict, List, Optional, Union, cast

import torch
from torch import Tensor, nn
from torch.optim.swa_utils import SWALR
from typing_extensions import override

import lightning.pytorch as pl
from lightning.fabric.utilities.types import LRScheduler
from lightning.pytorch.callbacks.callback import Callback
from lightning.pytorch.strategies import DeepSpeedStrategy
from lightning.pytorch.strategies.fsdp import FSDPStrategy
from lightning.pytorch.utilities.exceptions import MisconfigurationException
from lightning.pytorch.utilities.rank_zero import rank_zero_info, rank_zero_warn
from lightning.pytorch.utilities.types import LRSchedulerConfig

_AVG_FN = Callable[[Tensor, Tensor, Tensor], Tensor]


class StochasticWeightAveraging(Callback):
    def __init__(
        self,
        swa_lrs: Union[float, List[float]],
        swa_epoch_start: Union[int, float] = 0.8,
        annealing_epochs: int = 10,
        annealing_strategy: str = "cos",
        avg_fn: Optional[_AVG_FN] = None,
        device: Optional[Union[torch.device, str]] = torch.device("cpu"),
    ):
        r"""Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

        Stochastic Weight Averaging was proposed in ``Averaging Weights Leads to Wider Optima and
        Better Generalization`` by Pavel Izmailov, Dmitrii Podoprikhin, Timur Garipov, Dmitry Vetrov
        and Andrew Gordon Wilson (UAI 2018).

        This documentation is highly inspired by PyTorch's work on SWA. The callback arguments follow
        the scheme defined in PyTorch's ``swa_utils`` package.

        For an explanation of SWA, please take a look
        `here <https://pytorch.org/blog/pytorch-1.6-now-includes-stochastic-weight-averaging>`_.

        .. warning::  This is an :ref:`experimental <versioning:Experimental API>` feature.

        .. warning:: ``StochasticWeightAveraging`` is currently not supported for multiple
            optimizers/schedulers.

        .. warning:: ``StochasticWeightAveraging`` is currently only supported on every epoch.

        See also how to :ref:`enable it directly on the Trainer
        <advanced/training_tricks:Stochastic Weight Averaging>`.

        Arguments:

            swa_lrs: The SWA learning rate to use:

                - ``float``. Use this value for all parameter groups of the optimizer.
                - ``List[float]``. A list of values, one for each parameter group of the optimizer.

            swa_epoch_start: If provided as an int, the procedure will start from
                the ``swa_epoch_start``-th epoch. If provided as a float between 0 and 1,
                the procedure will start from epoch ``int(swa_epoch_start * max_epochs)``.

            annealing_epochs: number of epochs in the annealing phase (default: 10)

            annealing_strategy: Specifies the annealing strategy (default: "cos"):

                - ``"cos"``. For cosine annealing.
                - ``"linear"``. For linear annealing.

            avg_fn: the averaging function used to update the parameters;
                the function must take in the current value of the
                :class:`AveragedModel` parameter, the current value of the :attr:`model`
                parameter and the number of models already averaged; if ``None``,
                an equally weighted average is used (default: ``None``)

            device: if provided, the averaged model will be stored on the ``device``.
                When ``None`` is provided, it will infer the ``device`` from ``pl_module``.
                (default: ``"cpu"``)

        """
        err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
        if isinstance(swa_epoch_start, int) and swa_epoch_start < 1:
            raise MisconfigurationException(err_msg)
        if isinstance(swa_epoch_start, float) and not (0 <= swa_epoch_start <= 1):
            raise MisconfigurationException(err_msg)

        wrong_type = not isinstance(swa_lrs, (float, list))
        wrong_float = isinstance(swa_lrs, float) and swa_lrs <= 0
        wrong_list = isinstance(swa_lrs, list) and not all(lr > 0 and isinstance(lr, float) for lr in swa_lrs)
        if wrong_type or wrong_float or wrong_list:
            raise MisconfigurationException("The `swa_lrs` should be a positive float, or a list of positive floats")

        if avg_fn is not None and not callable(avg_fn):
            raise MisconfigurationException("The `avg_fn` should be callable.")

        if device is not None and not isinstance(device, (torch.device, str)):
            raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

        self.n_averaged: Optional[Tensor] = None
        self._swa_epoch_start = swa_epoch_start
        self._swa_lrs = swa_lrs
        self._annealing_epochs = annealing_epochs
        self._annealing_strategy = annealing_strategy
        self._avg_fn = avg_fn or self.avg_fn
        self._device = device
        self._model_contains_batch_norm: Optional[bool] = None
        self._average_model: Optional["pl.LightningModule"] = None
        self._initialized = False
        self._swa_scheduler: Optional[LRScheduler] = None
        self._scheduler_state: Optional[Dict] = None
        self._init_n_averaged = 0
        self._latest_update_epoch = -1
        self.momenta: Dict[nn.modules.batchnorm._BatchNorm, Optional[float]] = {}
        self._max_epochs: int

    @property
    def swa_start(self) -> int:
        assert isinstance(self._swa_epoch_start, int)
        return max(self._swa_epoch_start - 1, 0)  # 0-based

    @property
    def swa_end(self) -> int:
        return self._max_epochs - 1  # 0-based

    @staticmethod
    def pl_module_contains_batch_norm(pl_module: "pl.LightningModule") -> bool:
        return any(isinstance(module, nn.modules.batchnorm._BatchNorm) for module in pl_module.modules())

    @override
    def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: str) -> None:
        if isinstance(trainer.strategy, (FSDPStrategy, DeepSpeedStrategy)):
            raise MisconfigurationException("SWA does not currently support sharded models.")

        # copy the model before moving it to the accelerator device.
        self._average_model = deepcopy(pl_module)

    @override
    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if len(trainer.optimizers) != 1:
            raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

        if len(trainer.lr_scheduler_configs) > 1:
            raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

        assert trainer.max_epochs is not None
        if isinstance(self._swa_epoch_start, float):
            self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

        self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

        self._max_epochs = trainer.max_epochs
        if self._model_contains_batch_norm:
            # virtually increase max_epochs to perform the batch norm update on the latest epoch.
            assert trainer.fit_loop.max_epochs is not None
            trainer.fit_loop.max_epochs += 1

        if self._scheduler_state is not None:
            self._clear_schedulers(trainer)

    @override
    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
            self._initialized = True

            # move the average model to the requested device.
            assert self._average_model is not None
            self._average_model = self._average_model.to(self._device or pl_module.device)

            optimizer = trainer.optimizers[0]
            if isinstance(self._swa_lrs, float):
                self._swa_lrs = [self._swa_lrs] * len(optimizer.param_groups)

            for lr, group in zip(self._swa_lrs, optimizer.param_groups):
                group["initial_lr"] = lr

            assert trainer.max_epochs is not None
            self._swa_scheduler = cast(
                LRScheduler,
                SWALR(
                    optimizer,
                    swa_lr=self._swa_lrs,  # type: ignore[arg-type]
                    anneal_epochs=self._annealing_epochs,
                    anneal_strategy=self._annealing_strategy,
                    last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
                ),
            )
            if self._scheduler_state is not None:
                # Restore scheduler state from the checkpoint
                self._swa_scheduler.load_state_dict(self._scheduler_state)
            elif trainer.current_epoch != self.swa_start:
                # Log a warning if we're initializing after start without any checkpoint data,
                # as behaviour will be different compared to having checkpoint data.
                rank_zero_warn(
                    "SWA is initializing after swa_start without any checkpoint data. "
                    "This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
                )

            # We assert that there is only one optimizer on fit start
            default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler)
            assert default_scheduler_cfg.interval == "epoch"
            assert default_scheduler_cfg.frequency == 1

            if trainer.lr_scheduler_configs:
                scheduler_cfg = trainer.lr_scheduler_configs[0]
                if scheduler_cfg.interval != "epoch" or scheduler_cfg.frequency != 1:
                    rank_zero_warn(f"SWA is currently only supported every epoch. Found {scheduler_cfg}")
                rank_zero_info(
                    f"Swapping scheduler `{scheduler_cfg.scheduler.__class__.__name__}`"
                    f" for `{self._swa_scheduler.__class__.__name__}`"
                )
                trainer.lr_scheduler_configs[0] = default_scheduler_cfg
            else:
                trainer.lr_scheduler_configs.append(default_scheduler_cfg)

            if self.n_averaged is None:
                self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

        if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
            trainer.current_epoch > self._latest_update_epoch
        ):
            assert self.n_averaged is not None
            assert self._average_model is not None
            self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
            self._latest_update_epoch = trainer.current_epoch

        # Note: No > here in case the callback is saved with the model and training continues
        if trainer.current_epoch == self.swa_end + 1:
            # Transfer weights from the average model to pl_module
            assert self._average_model is not None
            self.transfer_weights(self._average_model, pl_module)

            # Reset BatchNorm for the update
            self.reset_batch_norm_and_save_state(pl_module)

            # There is no need to perform either backward or optimizer.step as we are
            # performing only one pass over the train data-loader to compute activation statistics.
            # Therefore, we will virtually increase the number of training batches by 1 and skip backward.
            trainer.fit_loop.max_batches += 1
            trainer.fit_loop._skip_backward = True
            self._accumulate_grad_batches = trainer.accumulate_grad_batches

            assert isinstance(trainer.fit_loop.max_batches, int), "Iterable-style datasets are not supported"
            trainer.accumulate_grad_batches = trainer.fit_loop.max_batches

    @override
    def on_train_epoch_end(self, trainer: "pl.Trainer", *args: Any) -> None:
        trainer.fit_loop._skip_backward = False

    @override
    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # the trainer increases the current epoch before this hook is called
        if self._model_contains_batch_norm and trainer.current_epoch - 1 == self.swa_end + 1:
            # BatchNorm epoch update. Reset state
            trainer.accumulate_grad_batches = self._accumulate_grad_batches
            trainer.fit_loop.max_batches -= 1
            assert trainer.fit_loop.max_epochs is not None
            trainer.fit_loop.max_epochs -= 1
            self.reset_momenta()
        elif trainer.current_epoch - 1 == self.swa_end:
            # Last SWA epoch. Transfer weights from the average model to pl_module
            assert self._average_model is not None
            self.transfer_weights(self._average_model, pl_module)

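    # Lifecycle summary (descriptive note derived from the hooks above): `setup` snapshots the model
    # before it is moved to the accelerator; `on_fit_start` enforces the single-optimizer/scheduler
    # constraint, resolves a fractional `swa_epoch_start`, and adds one virtual epoch when BatchNorm
    # layers are present; `on_train_epoch_start` swaps the scheduler for `SWALR` and folds the current
    # weights into the running average once per epoch between `swa_start` and `swa_end`; the extra
    # epoch runs a single forward-only pass to recompute BatchNorm statistics, and `on_train_end`
    # undoes that bookkeeping and restores the BatchNorm momenta.
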
    @staticmethod
    def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule") -> None:
        for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
            dst_param.detach().copy_(src_param.to(dst_param.device))

    def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
        self.momenta = {}
        for module in pl_module.modules():
            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                continue
            assert module.running_mean is not None
            module.running_mean = torch.zeros_like(
                module.running_mean,
                device=pl_module.device,
                dtype=module.running_mean.dtype,
            )
            assert module.running_var is not None
            module.running_var = torch.ones_like(
                module.running_var,
                device=pl_module.device,
                dtype=module.running_var.dtype,
            )
            self.momenta[module] = module.momentum
            module.momentum = None  # type: ignore[assignment]
            assert module.num_batches_tracked is not None
            module.num_batches_tracked *= 0

    def reset_momenta(self) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
        for bn_module in self.momenta:
            bn_module.momentum = self.momenta[bn_module]  # type: ignore[assignment]

    @staticmethod
    def update_parameters(
        average_model: "pl.LightningModule", model: "pl.LightningModule", n_averaged: Tensor, avg_fn: _AVG_FN
    ) -> None:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L104-L112."""
        for p_swa, p_model in zip(average_model.parameters(), model.parameters()):
            device = p_swa.device
            p_swa_ = p_swa.detach()
            p_model_ = p_model.detach().to(device)
            src = p_model_ if n_averaged == 0 else avg_fn(p_swa_, p_model_, n_averaged.to(device))
            p_swa_.copy_(src)
        n_averaged += 1

    @staticmethod
    def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
        """Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
        return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)

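    # Note on the default `avg_fn`: it is a cumulative moving average, so after folding in parameters
    # p_1, ..., p_n one at a time the stored value equals their arithmetic mean. Worked example with
    # hypothetical values p_1 = 2.0 and p_2 = 4.0: the first update copies 2.0 (num_averaged == 0 takes
    # the model parameter directly in `update_parameters`), and the second stores
    # 2.0 + (4.0 - 2.0) / (1 + 1) = 3.0, i.e. mean(2.0, 4.0).
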
    @override
    def state_dict(self) -> Dict[str, Any]:
        return {
            "n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
            "latest_update_epoch": self._latest_update_epoch,
            "scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
            "average_model_state": None if self._average_model is None else self._average_model.state_dict(),
        }

    @override
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._init_n_averaged = state_dict["n_averaged"]
        self._latest_update_epoch = state_dict["latest_update_epoch"]
        self._scheduler_state = state_dict["scheduler_state"]
        self._load_average_model_state(state_dict["average_model_state"])

    @staticmethod
    def _clear_schedulers(trainer: "pl.Trainer") -> None:
        # If we have scheduler state saved, clear the scheduler configs so that we don't try to
        # load state into the wrong type of schedulers when restoring scheduler checkpoint state.
        # We'll configure the scheduler and re-load its state in on_train_epoch_start.
        # Note that this relies on the callback state being restored before the scheduler state is
        # restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
        # writing that is only True for deepspeed which is already not supported by SWA.
        # See https://github.com/Lightning-AI/lightning/issues/11665 for background.
        if trainer.lr_scheduler_configs:
            assert len(trainer.lr_scheduler_configs) == 1
            trainer.lr_scheduler_configs.clear()

    def _load_average_model_state(self, model_state: Any) -> None:
        if self._average_model is None:
            return
        self._average_model.load_state_dict(model_state)