
QuantizationAwareTraining

class pytorch_lightning.callbacks.QuantizationAwareTraining(qconfig='fbgemm', observer_type='average', collect_quantization=None, modules_to_fuse=None, input_compatible=True, quantize_on_fit_end=True, observer_enabled_stages=('train',))[source]

Bases: pytorch_lightning.callbacks.base.Callback

Quantization speeds up inference and reduces memory requirements by performing computations and storing tensors at lower bit widths (such as INT8 or FLOAT16) than full floating-point precision. This callback uses the native PyTorch API; for more information, see PyTorch Quantization.

Warning

QuantizationAwareTraining is in beta and subject to change.

The LightningModule is prepared for QAT in the on_fit_start hook. Checkpoints saved during training include the statistics already collected for the quantization conversion, but they do not contain the quantized or fused model/layers. The quantization is performed in the on_fit_end hook, so the model needs to be saved after training finishes if quantization is desired.
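The lifecycle described above suggests saving only once fit has returned. The following is a minimal sketch of that ordering, not an official recipe: fit_then_save, ckpt_path and max_epochs are hypothetical names chosen for illustration, and model is assumed to be a LightningModule that defines its own training dataloader.

```python
def fit_then_save(model, ckpt_path="qat_model.ckpt", max_epochs=3):
    """Train `model` with the QAT callback, then save after fit has finished."""
    # Imported lazily so this sketch can be loaded without pytorch_lightning installed.
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import QuantizationAwareTraining

    # Defaults from the signature: qconfig='fbgemm', quantize_on_fit_end=True.
    qat = QuantizationAwareTraining()
    trainer = Trainer(max_epochs=max_epochs, callbacks=[qat])

    # The callback prepares the model in on_fit_start and converts it in on_fit_end.
    trainer.fit(model)

    # Saving here, after fit returns, captures the state left by on_fit_end.
    trainer.save_checkpoint(ckpt_path)
    return trainer
```

Checkpoints written by the Trainer while fit is still running would, per the note above, hold only the collected statistics rather than the quantized layers.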

Parameters
  • qconfig (Union[str, QConfig]) –

    quantization configuration, given either as a backend name string (default 'fbgemm') or as a custom QConfig.

  • observer_type (str) – switches between MovingAverageMinMaxObserver (“average”, the default) and HistogramObserver (“histogram”), which is more computationally expensive.

  • collect_quantization (Union[Callable, int, None]) –

    count or custom function to collect quantization statistics:

    • None (default). The quantization observer is called in each module forward (useful for collecting extended statistics when using image/data augmentation).

    • int. Use to set a fixed number of calls, starting from the beginning.

    • Callable. Custom function with a single trainer argument.

      For example, to collect statistics only during the last epoch:

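      from pytorch_lightning.callbacks import QuantizationAwareTraining


      def custom_trigger_last(trainer):
          # Collect statistics only during the final epoch (epochs are zero-indexed).
          return trainer.current_epoch == (trainer.max_epochs - 1)


      QuantizationAwareTraining(collect_quantization=custom_trigger_last)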

  • modules_to_fuse (Optional[Sequence]) – allows you to fuse a few layers together, as shown in the diagram. To find which layer types can be fused, check https://github.com/pytorch/pytorch/pull/43286.

  • input_compatible (bool) – preserve quant/dequant layers. This allows feeding the model any input that the original model accepts, but breaks compatibility with TorchScript and export with torch.save.

  • quantize_on_fit_end (bool) – perform the quantization in on_fit_end. Note that once converted, the model cannot be put in training mode again.

  • observer_enabled_stages (Sequence[str]) –

    allows the observers of fake-quantization modules to calibrate during the given stages:

    • 'train': the observers can do calibration during training.

    • 'validate': the observers can do calibration during validation. Note that observers are not disabled during the sanity check, as the model hasn’t been calibrated with training data yet. After the sanity check, the fake-quantization modules are restored to their initial states.

    • 'test': the observers can do calibration during testing.

    • 'predict': the observers can do calibration during predicting.

    Note that we only handle observers belonging to fake-quantization modules. When qconfig is a str and observer_type is 'histogram', the observers won’t belong to any fake-quantization modules and will not be controlled by the callback.
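The three collect_quantization modes listed above (None, int, Callable) can be summarized as a single decision per forward call. The sketch below illustrates the documented semantics only and is not the callback's actual implementation: should_collect and forward_calls are hypothetical names, and a SimpleNamespace stands in for a real Trainer.

```python
from types import SimpleNamespace
from typing import Callable, Union


def should_collect(
    collect_quantization: Union[Callable, int, None],
    trainer,
    forward_calls: int,
) -> bool:
    """Hypothetical helper: decide whether this forward call collects statistics."""
    if collect_quantization is None:
        # Default: collect on every forward call.
        return True
    if isinstance(collect_quantization, int):
        # Fixed budget: only the first N forward calls collect.
        return forward_calls < collect_quantization
    # Callable: defer to the user-supplied predicate on the trainer.
    return bool(collect_quantization(trainer))


def custom_trigger_last(trainer):
    # Same predicate as in the parameter description: true only in the last epoch.
    return trainer.current_epoch == (trainer.max_epochs - 1)


# Stand-in for a Trainer, exposing only the two attributes the predicate reads.
fake_trainer = SimpleNamespace(current_epoch=2, max_epochs=3)

print(should_collect(None, fake_trainer, forward_calls=10))                 # True
print(should_collect(5, fake_trainer, forward_calls=10))                    # False
print(should_collect(custom_trigger_last, fake_trainer, forward_calls=10))  # True
```

An integer budget suits a short calibration at the start of training, while a callable lets collection depend on trainer state such as the current epoch.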

on_fit_end(trainer, pl_module)[source]

Called when fit ends.

Return type

None

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

on_predict_end(trainer, pl_module)[source]

Called when predict ends.

Return type

None

on_predict_start(trainer, pl_module)[source]

Called when the predict begins.

Return type

None

on_test_end(trainer, pl_module)[source]

Called when the test ends.

Return type

None

on_test_start(trainer, pl_module)[source]

Called when the test begins.

Return type

None

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type

None

on_train_start(trainer, pl_module)[source]

Called when the train begins.

Return type

None

on_validation_end(trainer, pl_module)[source]

Called when the validation loop ends.

Return type

None

on_validation_start(trainer, pl_module)[source]

Called when the validation loop begins.

Return type

None

state_dict()[source]

Called when saving a checkpoint. Implement it to generate the callback’s state_dict.

Return type

Dict[str, Any]

Returns

A dictionary containing callback state.