Pruning and Quantization
Pruning and quantization are techniques for compressing a model for deployment. They speed up inference and save energy without a significant loss in accuracy.
Pruning
Warning
This is an experimental feature.
Pruning removes some of a model's weights to reduce the model size and lower inference requirements.
Pruning has been shown to achieve significant efficiency improvements while minimizing the drop in model performance (prediction quality). Model pruning is recommended for cloud endpoints, models deployed on edge devices, and mobile inference, among other use cases.
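To make the idea concrete, here is a minimal pure-Python sketch of magnitude-based pruning, which zeroes the weights with the smallest absolute value. The helper name `magnitude_prune` and the sample weights are illustrative assumptions, not part of Lightning or PyTorch; real implementations operate on tensors and use masks rather than rewriting lists.

```python
def magnitude_prune(weights, amount):
    """Zero out the `amount` fraction of weights with the smallest magnitude.

    Hypothetical helper for illustration only.
    """
    n_prune = round(amount * len(weights))
    # Indices sorted by ascending absolute value; the first n_prune are removed.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]))
    pruned = set(order[:n_prune])
    return [0.0 if i in pruned else w for i, w in enumerate(weights)]


weights = [0.8, -0.05, 0.3, -1.2, 0.01, 0.6]
sparse = magnitude_prune(weights, amount=0.5)
print(sparse)  # [0.8, 0.0, 0.0, -1.2, 0.0, 0.6]
```

Half of the entries are now zero, and the three largest-magnitude weights are untouched.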
To enable pruning during training in Lightning, pass the ModelPruning callback to the Lightning Trainer. PyTorch's native pruning implementation is used under the hood.
This callback supports multiple pruning functions: pass the name of any torch.nn.utils.prune function as a string to select which weights to prune (random_unstructured, random_structured, etc.), or implement your own by subclassing BasePruningMethod.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning

# set the amount to be the fraction of parameters to prune
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=0.5)])
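For the custom route mentioned above, the following is a minimal sketch of a BasePruningMethod subclass, applied directly with PyTorch so it runs without Lightning. The class name `EveryOtherPruning` and its rule (mask every other entry) are hypothetical and chosen only to keep the example short; a real method would pick weights by some importance criterion.

```python
import torch
from torch import nn
from torch.nn.utils import prune


class EveryOtherPruning(prune.BasePruningMethod):
    """Hypothetical method: mask every other entry of the flattened tensor."""

    # "unstructured" means individual entries are pruned, not whole channels.
    PRUNING_TYPE = "unstructured"

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0  # 0 marks a pruned entry, 1 a kept one
        return mask


layer = nn.Linear(4, 2)  # weight has 2 * 4 = 8 entries
EveryOtherPruning.apply(layer, name="weight")

# PyTorch now stores `weight_orig` and `weight_mask`; `weight` is their product.
print(int(layer.weight_mask.sum()))
```

Following the paragraph above, a subclass like this is what would be handed to ModelPruning in place of a string name.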
You can also perform iterative pruning, apply the lottery ticket hypothesis, and more!
def compute_amount(epoch):
    # the sum of all returned values needs to be smaller than 1
    if epoch == 10:
        return 0.5
    elif epoch == 50:
        return 0.25
    elif 75 < epoch < 99:
        return 0.01


# the amount can also be a callable
trainer = Trainer(callbacks=[ModelPruning("l1_unstructured", amount=compute_amount)])
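As a quick sanity check on a schedule like this one, the sketch below evaluates the callable over a run and confirms that the returned amounts sum to less than 1, as the comment in the example requires. The 100-epoch run length (`max_epochs`) is an assumption made for illustration; the function is repeated so the snippet runs on its own.

```python
def compute_amount(epoch):
    if epoch == 10:
        return 0.5
    elif epoch == 50:
        return 0.25
    elif 75 < epoch < 99:
        return 0.01


max_epochs = 100  # assumed training length, not taken from the example above

# Collect only the epochs where the schedule returns an amount.
schedule = {}
for epoch in range(max_epochs):
    amount = compute_amount(epoch)
    if amount is not None:
        schedule[epoch] = amount

total = sum(schedule.values())
print(len(schedule), round(total, 2))  # 25 0.98
assert total < 1, "pruning amounts must sum to less than 1"
```

Pruning is scheduled at epoch 10, epoch 50, and the 23 epochs from 76 through 98, for a total of 0.5 + 0.25 + 23 × 0.01 = 0.98.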