Takeaways
Learn how quantization reduces the memory footprint of models like Llama 2 by as much as 4x!
Introduction
The aim of quantization is to reduce the memory usage of the model parameters by storing them in lower-precision types than the typical float32 or (b)float16. Lower bit widths like 8-bit and 4-bit use a quarter to an eighth of the memory of float32 (32-bit) and half to a quarter of (b)float16 (16-bit). The quantization procedure does not simply trim the number of bits used; it rescales and compresses the values to limit the amount of information lost.
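To build intuition before reaching for a library, here is a minimal sketch of absolute-maximum (absmax) quantization to int8 in plain PyTorch; the tensor values are illustrative and this is not Fabric's internal implementation:
import torch

# a toy float32 weight tensor
weights = torch.tensor([0.02, -1.5, 0.73, 3.1, -0.004])

# absmax quantization: rescale values into the int8 range [-127, 127]
scale = 127 / weights.abs().max()            # the quantization constant
q_weights = torch.round(weights * scale).to(torch.int8)

# dequantization recovers an approximation of the original values
deq_weights = q_weights.to(torch.float32) / scale
print(q_weights)    # 1 byte per value instead of 4
print(deq_weights)  # close to, but not exactly, the original weights
Libraries like bitsandbytes apply this idea block-wise, storing one quantization constant per block of weights.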
Using quantization to compress models that have billions of parameters like Llama 2 or SDXL makes deployment on edge devices with less memory capacity possible. Thankfully, Lightning Fabric makes quantization as easy as setting a mode flag!
from lightning_fabric import Fabric
from lightning_fabric.plugins import BitsandbytesPrecision
# all available quantization modes
# "nf4", "nf4-dq", "fp4", "fp4-dq", "int8", "int8-training"
mode = "nf4"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)
model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
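For context, here is one hypothetical definition of the CustomModule placeholder used above. bitsandbytes quantizes nn.Linear layers, so any PyTorch model containing them works:
import torch.nn as nn

class CustomModule(nn.Module):
    # stand-in for "your PyTorch model"; the nn.Linear layers are what
    # fabric.setup_module replaces with quantized bitsandbytes layers
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 10),
        )

    def forward(self, x):
        return self.net(x)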
Learning about quantization is an absolute must to move models from idea to training and production at the edge. We’ll cover 8-bit, 4-bit, and double quantization below.
8-bit Quantization
8-bit quantization is discussed in the popular paper 8-bit Optimizers via Block-wise Quantization, and 8-bit floating-point formats are described in FP8 Formats for Deep Learning. As the 8-bit optimizers paper notes, 8-bit quantization was the natural progression after 16-bit precision. Even so, the implementation was not as simple as moving from FP32 to FP16, because those two floating-point types share the same representation scheme while 8-bit does not.
8-bit quantization requires a new representation scheme, and that scheme can represent far fewer distinct values than FP16 or FP32. This means model performance may be affected by quantization, so it is good to be aware of the trade-off. In particular, if the weights will be deployed to an edge device that requires quantization, the model should be evaluated in its quantized form.
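To see where that information loss comes from, here is a rough illustration (not a benchmark) that round-trips a random weight vector through int8 absmax quantization and through FP16, and compares the reconstruction error:
import torch

torch.manual_seed(0)
w = torch.randn(4096)  # stand-in for a block of float32 weights

# int8 absmax round-trip: rescale into [-127, 127], round, then undo the scale
scale = 127 / w.abs().max()
w_int8 = torch.round(w * scale).to(torch.int8).to(torch.float32) / scale

# fp16 round-trip for comparison
w_fp16 = w.to(torch.float16).to(torch.float32)

print("int8 mean abs error:", (w - w_int8).abs().mean().item())
print("fp16 mean abs error:", (w - w_fp16).abs().mean().item())
The int8 error is noticeably larger, which is exactly why a quantized model should be re-evaluated before deployment.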
Lightning Fabric can use 8-bit quantization by setting the mode flag to int8-training for training, or int8 for inference.
from lightning_fabric import Fabric
from lightning_fabric.plugins import BitsandbytesPrecision
# available 8-bit quantization modes
# ("int8", "int8-training")
mode = "int8"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)
model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
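To check the savings on your own model, here is a small sketch (the helper below is ours, not part of Fabric) that sums the storage used by a module's parameters and buffers; note that bitsandbytes may only materialize the quantized weights once the model is on the GPU, so measure after fabric.setup_module has run:
def memory_footprint_mb(module):
    # total bytes currently held by the module's parameters and buffers
    total = sum(p.numel() * p.element_size() for p in module.parameters())
    total += sum(b.numel() * b.element_size() for b in module.buffers())
    return total / 1e6

print(f"quantized model: {memory_footprint_mb(model):.1f} MB")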
Just as 8-bit quantization is the natural progression from 16-bit precision, 4-bit quantization is the next smallest representation scheme. Let’s talk about 4-bit quantization in the following sections.
4-bit Quantization
4-bit quantization is discussed in the popular paper QLoRA: Efficient Finetuning of Quantized LLMs. QLoRA is a finetuning method that uses 4-bit quantization. The paper introduces this finetuning technique and demonstrates how, using the NF4 (NormalFloat) format, it can “finetune a 65B parameter model on a single 48GB GPU while preserving full 16-bit finetuning task performance.”
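To make the NF4 idea concrete, here is a rough sketch of how a NormalFloat-style codebook maps a block of weights to 4-bit indices. The 16 code values below are quantile-spaced for illustration only; bitsandbytes hard-codes its own exact table:
import torch

# 16 code values spaced by quantiles of a standard normal distribution,
# rescaled to [-1, 1] (an approximation of the NF4 codebook)
probs = torch.linspace(0.01, 0.99, 16)
codebook = torch.distributions.Normal(0.0, 1.0).icdf(probs)
codebook = codebook / codebook.abs().max()

# normalize a block of weights by its absmax, then snap each value to the
# nearest code; only the 4-bit indices and the absmax need to be stored
block = torch.randn(64)
absmax = block.abs().max()
idx = (block / absmax - codebook.view(-1, 1)).abs().argmin(dim=0)
dequant = codebook[idx] * absmax  # approximate reconstruction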
Lightning Fabric can use 4-bit quantization by setting the mode flag to either nf4 or fp4.
from lightning_fabric import Fabric
from lightning_fabric.plugins import BitsandbytesPrecision
# available 4-bit quantization modes
# ("nf4", "fp4")
mode = "nf4"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)
model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
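Depending on the Fabric version you have installed, BitsandbytesPrecision also accepts a compute dtype and a set of module names to leave unquantized; treat the exact arguments below as a sketch and check the documentation for your version:
import torch

# example: compute in bfloat16 and skip quantizing a module named "lm_head"
plugin = BitsandbytesPrecision(mode="nf4", dtype=torch.bfloat16, ignore_modules={"lm_head"})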
Double Quantization
Double quantization is an extra 4-bit quantization setting introduced alongside NF4 in QLoRA: Efficient Finetuning of Quantized LLMs. It works by quantizing the quantization constants (the per-block scaling factors) that are internal to bitsandbytes’ procedures, which the paper reports saves roughly 0.37 bits per parameter on average.
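A back-of-the-envelope calculation with the block sizes reported in the QLoRA paper (64 weights per first-level block, 256 constants per second-level block) shows what this buys:
# without double quantization: one float32 constant per 64-weight block
bits_per_param = 32 / 64                       # 0.5 extra bits per parameter

# with double quantization: constants stored in 8 bits, plus one float32
# constant per block of 256 first-level constants
bits_per_param_dq = 8 / 64 + 32 / (64 * 256)   # ~0.127 extra bits per parameter

print(bits_per_param - bits_per_param_dq)      # ~0.37 bits saved per parameter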
Lightning Fabric can use 4-bit double quantization by setting the mode flag to either nf4-dq or fp4-dq.
from lightning_fabric import Fabric
from lightning_fabric.plugins import BitsandbytesPrecision
# available 4-bit double quantization modes
# ("nf4-dq", "fp4-dq")
mode = "nf4-dq"
plugin = BitsandbytesPrecision(mode=mode)
fabric = Fabric(plugins=plugin)
model = CustomModule() # your PyTorch model
model = fabric.setup_module(model) # quantizes the layers
Conclusion
Quantization is a must for most production systems, given that edge devices and consumer-grade hardware typically require models with a much smaller memory footprint than powerful accelerators such as NVIDIA’s A100 80GB can accommodate. Learning this technique will give you a better understanding of how to deploy models like Llama 2 and SDXL, and of the requirements of edge devices in robotics, vehicles, and other systems.
Still have questions?
We have an amazing community and team of core engineers ready to answer your questions. So, join us on Discourse or Discord. See you there!
Resources and References
Introduction to Quantization
Introduction to Quantization and API Summary
Quantization in Practice
Post Training Quantization
Quantization in Lightning Fabric
FP8 Formats for Deep Learning
8-bit Optimizers via Block-wise Quantization
QLoRA: Efficient Finetuning of Quantized LLMs
GPTQ: Accurate Post-Training Quantization for Generative Pre-trained Transformers
TensorFloat-32 in the A100 GPU Accelerates AI Training, HPC up to 20x
Automatic Mixed Precision for Deep Learning