Source code for pytorch_lightning.callbacks.quantization
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Quantization
^^^^^^^^^^^^
"""
import copy
import functools
from typing import Any, Callable, Dict, Optional, Sequence, Union
import torch
from torch import Tensor
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8
if _TORCH_GREATER_EQUAL_1_8:
    from torch.quantization import FakeQuantizeBase
else:
    # For torch 1.6 and 1.7.
    from torch.quantization import FakeQuantize as FakeQuantizeBase
import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities import _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.exceptions import MisconfigurationException
if _TORCH_GREATER_EQUAL_1_10:
    from torch.ao.quantization.qconfig import QConfig
else:
    from torch.quantization import QConfig
def wrap_qat_forward_context(
    quant_cb, model: "pl.LightningModule", func: Callable, trigger_condition: Optional[Union[Callable, int]] = None
) -> Callable:
    """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
    compatibility Moreover this version has the (de)quantization conditional as it may not be needed for the
    training all the time."""
    # todo: consider using registering hook before/after forward
    @functools.wraps(func)
    def wrapper(data) -> Any:
        _is_func_true = isinstance(trigger_condition, Callable) and trigger_condition(model.trainer)
        _is_count_true = isinstance(trigger_condition, int) and quant_cb._forward_calls < trigger_condition
        _quant_run = trigger_condition is None or _is_func_true or _is_count_true
        # apply custom trigger
        if _quant_run:
            quant_cb._forward_calls += 1
            data = model.quant(data)
        data = func(data)
        # apply custom trigger
        if _quant_run:
            data = model.dequant(data)
        return data
    return wrapper
def wrap_quantize_forward_context(model: "pl.LightningModule", func: Callable) -> Callable:
    """Decorator to wrap forward path as it is needed to quantize inputs and dequantize outputs for in/out
    compatibility."""
    # todo: consider using registering hook before/after forward
    @functools.wraps(func)
    def wrapper(data) -> Any:
        data = model.quant(data)
        data = func(data)
        data = model.dequant(data)
        return data
    return wrapper
def _recursive_hasattr(obj: Any, attribs: str, state: bool = True) -> bool:
    """recursive check if model has some layers denoted with '.'."""
    if "." in attribs:
        attrib, attribs = attribs.split(".", 1)
        if hasattr(obj, attrib):
            return _recursive_hasattr(getattr(obj, attrib), attribs, state)
        return False
    return state and hasattr(obj, attribs)
[docs]class QuantizationAwareTraining(Callback):
    """Quantization allows speeding up inference and decreasing memory requirements by performing computations and
    storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision. We use native
    PyTorch API so for more information see `PyTorch Quantization`_.
    .. warning:: ``QuantizationAwareTraining`` is in beta and subject to change.
    Args:
        qconfig: quantization configuration:
            - 'fbgemm' for server inference.
            - 'qnnpack' for mobile inference.
            - a custom `torch.quantization.QConfig`_.
        observer_type: allows switching between ``MovingAverageMinMaxObserver`` as "average" (default)
            and ``HistogramObserver`` as "histogram" which is more computationally expensive.
        collect_quantization: count or custom function to collect quantization statistics:
            - ``None`` (deafult). The quantization observer is called in each module forward
                (useful for collecting extended statistic when useing image/data augmentation).
            - ``int``. Use to set a fixed number of calls, starting from the beginning.
            - ``Callable``. Custom function with single trainer argument.
                See this example to trigger only the last epoch:
                .. code-block:: python
                    def custom_trigger_last(trainer):
                        return trainer.current_epoch == (trainer.max_epochs - 1)
                    QuantizationAwareTraining(collect_quantization=custom_trigger_last)
        modules_to_fuse: allows you fuse a few layers together as shown in
            `diagram <https://pytorch.org/docs/stable/quantization.html#quantization-aware-training>`_
            to find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.
        input_compatible: preserve quant/dequant layers. This allows to feat any input as to the original model,
            but break compatibility to torchscript and export with ``torch.save``.
        quantize_on_fit_end: perform the quantization in `on_fit_end`.
            Note that once converted, the model cannot be put in training mode again.
        observer_enabled_stages: allow fake-quantization modules' observers to do calibration during provided stages:
            - ``'train'``: the observers can do calibration during training.
            - ``'validate'``: the observers can do calibration during validating.
              Note that we don't disable observers during the sanity check as the model hasn't been calibrated with
              training data yet. After the sanity check, the fake-quantization modules are restored to initial states.
            - ``'test'``: the observers can do calibration during testing.
            - ``'predict'``: the observers can do calibration during predicting.
            Note that we only handle observers belonging to fake-quantization modules. When ``qconfig`` is a ``str`` and
            ``observer_type`` is ``'histogram'``, the observers won't belong to any fake-quantization modules and will
            not be controlled by the callback.
    .. _PyTorch Quantization: https://pytorch.org/docs/stable/quantization.html#quantization-aware-training
    .. _torch.quantization.QConfig: https://pytorch.org/docs/stable/torch.quantization.html#torch.quantization.QConfig
    """
    OBSERVER_TYPES = ("histogram", "average")
    OBSERVER_STAGES = ("train", "validate", "test", "predict")
    def __init__(
        self,
        qconfig: Union[str, QConfig] = "fbgemm",
        observer_type: str = "average",
        collect_quantization: Optional[Union[int, Callable]] = None,
        modules_to_fuse: Optional[Sequence] = None,
        input_compatible: bool = True,
        quantize_on_fit_end: bool = True,
        observer_enabled_stages: Sequence[str] = ("train",),
    ) -> None:
        _valid_qconf_str = isinstance(qconfig, str) and qconfig in torch.backends.quantized.supported_engines
        if not isinstance(qconfig, QConfig) and not _valid_qconf_str:
            raise MisconfigurationException(
                f"Unsupported qconfig: f{qconfig}.\nTry one of defaults: {torch.backends.quantized.supported_engines}"
            )
        self._qconfig = qconfig
        if observer_type not in self.OBSERVER_TYPES:
            raise MisconfigurationException(
                f'Unsupported observer type "{observer_type}", allowed are {self.OBSERVER_TYPES}.'
            )
        self._observer_type = observer_type
        if collect_quantization is not None and not isinstance(collect_quantization, (int, Callable)):
            raise MisconfigurationException(
                f'Unsupported `collect_quantization` "{collect_quantization}", allowed are `int` or `Callable`.'
            )
        self._collect_quantization = collect_quantization
        self.modules_to_fuse = modules_to_fuse
        self._input_compatible = input_compatible
        self._convert_on_fit_end = quantize_on_fit_end
        observer_enabled_stages = set(observer_enabled_stages)
        unsupported_stages = observer_enabled_stages - set(self.OBSERVER_STAGES)
        if unsupported_stages:
            raise MisconfigurationException(
                f'Unsupported stages "{tuple(sorted(unsupported_stages))}", allowed are {self.OBSERVER_STAGES}.'
            )
        self._observer_disabled_stages = set(self.OBSERVER_STAGES) - observer_enabled_stages
        self._forward_calls = 0
        self._fake_quant_to_initial_state_dict = {}
        self._last_fake_quant_to_observer_enabled = {}
    def _check_feasible_fuse(self, model: "pl.LightningModule") -> bool:
        if not self.modules_to_fuse:
            return False
        for group in self.modules_to_fuse:
            if not all(_recursive_hasattr(model, m) for m in group):
                raise MisconfigurationException(
                    f"You have requested to fuse {group} but one or more of them is not your model attributes"
                )
        return True
    def _collect_observer_enabled(self) -> Dict[FakeQuantizeBase, Tensor]:
        return {
            fake_quant: fake_quant.observer_enabled.clone() for fake_quant in self._fake_quant_to_initial_state_dict
        }
    def _disable_observer(self, pl_module: "pl.LightningModule") -> None:
        self._last_fake_quant_to_observer_enabled = self._collect_observer_enabled()
        pl_module.apply(torch.quantization.disable_observer)
    def _restore_last_observer_enabled(self) -> None:
        for fake_quant, observer_enabled in self._last_fake_quant_to_observer_enabled.items():
            fake_quant.observer_enabled.copy_(observer_enabled)
[docs]    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # QuantStub converts tensors from floating point to quantized
        pl_module.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point
        pl_module.dequant = torch.quantization.DeQuantStub()
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        self.__module_forward = pl_module.forward
        pl_module.forward = wrap_qat_forward_context(
            quant_cb=self, model=pl_module, func=pl_module.forward, trigger_condition=self._collect_quantization
        )
        # attach a global qconfig, which contains information about what kind
        # of observers to attach. Use 'fbgemm' for server inference
        if isinstance(self._qconfig, str):
            if self._observer_type == "histogram":
                pl_module.qconfig = torch.quantization.get_default_qconfig(self._qconfig)
            elif self._observer_type == "average":
                # version=None corresponds to using FakeQuantize rather than
                # FusedMovingAvgObsFakeQuantize which was introduced in PT1.10
                # details in https://github.com/pytorch/pytorch/issues/64564
                extra_kwargs = dict(version=None) if _TORCH_GREATER_EQUAL_1_10 else {}
                pl_module.qconfig = torch.quantization.get_default_qat_qconfig(self._qconfig, **extra_kwargs)
        elif isinstance(self._qconfig, QConfig):
            pl_module.qconfig = self._qconfig
        if self._check_feasible_fuse(pl_module):
            torch.quantization.fuse_modules(pl_module, self.modules_to_fuse, inplace=True)
        # Prepare the model for QAT. This inserts observers and fake_quants in
        # the model that will observe weight and activation tensors during calibration.
        torch.quantization.prepare_qat(pl_module, inplace=True)
        fake_quants = tuple(module for module in pl_module.modules() if isinstance(module, FakeQuantizeBase))
        self._fake_quant_to_initial_state_dict = {
            fake_quant: copy.deepcopy(fake_quant.state_dict()) for fake_quant in fake_quants
        }
[docs]    def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if not self._convert_on_fit_end:
            pl_module.forward = self.__module_forward
            return
        pl_module.eval()
        # Convert the observed model to a quantized model. This does several things:
        # quantizes the weights, computes and stores the scale and bias value to be
        # used with each activation tensor, fuses modules where appropriate,
        # and replaces key operators with quantized implementations.
        torch.quantization.convert(pl_module, inplace=True)
        # check we shall preserve wrapper
        if self._input_compatible:
            pl_module.forward = wrap_quantize_forward_context(model=pl_module, func=self.__module_forward)
        else:
            pl_module.forward = self.__module_forward
[docs]    def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "train" in self._observer_disabled_stages:
            self._disable_observer(pl_module)
[docs]    def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "train" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()
[docs]    def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "validate" in self._observer_disabled_stages and not trainer.sanity_checking:
            # ``torch.quantization.MovingAveragePerChannelMinMaxObserver`` and ``torch.quantization.HistogramObserver``
            # need to see at least one batch to infer the shapes of quantization ``scale`` and ``zero_point``. So we
            # don't disable observers during the sanity check so that they can infer the shapes of quantization
            # parameters with validation data.
            self._disable_observer(pl_module)
[docs]    def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "validate" in self._observer_disabled_stages:
            if trainer.sanity_checking:
                for fake_quant, state_dict in self._fake_quant_to_initial_state_dict.items():
                    fake_quant.load_state_dict(state_dict)
            else:
                self._restore_last_observer_enabled()
[docs]    def on_test_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "test" in self._observer_disabled_stages:
            self._disable_observer(pl_module)
[docs]    def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "test" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()
[docs]    def on_predict_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "predict" in self._observer_disabled_stages:
            self._disable_observer(pl_module)
[docs]    def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        if "predict" in self._observer_disabled_stages:
            self._restore_last_observer_enabled()