Source code for pytorch_lightning.callbacks.pruning
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License.r"""ModelPruning^^^^^^^^^^^^"""importinspectimportloggingfromcopyimportdeepcopyfromfunctoolsimportpartialfromtypingimportAny,Callable,Dict,List,Optional,Sequence,Tuple,Unionimporttorchimporttorch.nn.utils.pruneaspytorch_prunefromtorchimportnnfromtyping_extensionsimportTypedDictimportpytorch_lightningasplfrompytorch_lightning.callbacks.baseimportCallbackfrompytorch_lightning.core.lightningimportLightningModulefrompytorch_lightning.utilities.apply_funcimportapply_to_collectionfrompytorch_lightning.utilities.exceptionsimportMisconfigurationExceptionfrompytorch_lightning.utilities.rank_zeroimportrank_zero_debug,rank_zero_onlylog=logging.getLogger(__name__)_PYTORCH_PRUNING_FUNCTIONS={"ln_structured":pytorch_prune.ln_structured,"l1_unstructured":pytorch_prune.l1_unstructured,"random_structured":pytorch_prune.random_structured,"random_unstructured":pytorch_prune.random_unstructured,}_PYTORCH_PRUNING_METHOD={"ln_structured":pytorch_prune.LnStructured,"l1_unstructured":pytorch_prune.L1Unstructured,"random_structured":pytorch_prune.RandomStructured,"random_unstructured":pytorch_prune.RandomUnstructured,}_PARAM_TUPLE=Tuple[nn.Module,str]_PARAM_LIST=Sequence[_PARAM_TUPLE]_MODULE_CONTAINERS=(LightningModule,nn.Sequential,nn.ModuleList,nn.ModuleDict)class_LayerRef(TypedDict):data:nn.Modulenames:List[Tuple[int,str]]
[docs]classModelPruning(Callback):PARAMETER_NAMES=("weight","bias")def__init__(self,pruning_fn:Union[Callable,str],parameters_to_prune:_PARAM_LIST=(),parameter_names:Optional[List[str]]=None,use_global_unstructured:bool=True,amount:Union[int,float,Callable[[int],Union[int,float]]]=0.5,apply_pruning:Union[bool,Callable[[int],bool]]=True,make_pruning_permanent:bool=True,use_lottery_ticket_hypothesis:Union[bool,Callable[[int],bool]]=True,resample_parameters:bool=False,pruning_dim:Optional[int]=None,pruning_norm:Optional[int]=None,verbose:int=0,prune_on_train_epoch_end:bool=True,)->None:"""Model pruning Callback, using PyTorch's prune utilities. This callback is responsible of pruning networks parameters during training. To learn more about pruning with PyTorch, please take a look at `this tutorial <https://pytorch.org/tutorials/intermediate/pruning_tutorial.html>`_. .. warning:: ``ModelPruning`` is in beta and subject to change. .. code-block:: python parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")] trainer = Trainer( callbacks=[ ModelPruning( pruning_fn="l1_unstructured", parameters_to_prune=parameters_to_prune, amount=0.01, use_global_unstructured=True, ) ] ) When ``parameters_to_prune`` is ``None``, ``parameters_to_prune`` will contain all parameters from the model. The user can override ``filter_parameters_to_prune`` to filter any ``nn.Module`` to be pruned. Args: pruning_fn: Function from torch.nn.utils.prune module or your own PyTorch ``BasePruningMethod`` subclass. Can also be string e.g. `"l1_unstructured"`. See pytorch docs for more details. parameters_to_prune: List of tuples ``(nn.Module, "parameter_name_string")``. parameter_names: List of parameter names to be pruned from the nn.Module. Can either be ``"weight"`` or ``"bias"``. use_global_unstructured: Whether to apply pruning globally on the model. If ``parameters_to_prune`` is provided, global unstructured will be restricted on them. amount: Quantity of parameters to prune: - ``float``. Between 0.0 and 1.0. Represents the fraction of parameters to prune. - ``int``. Represents the absolute number of parameters to prune. - ``Callable``. For dynamic values. Will be called every epoch. Should return a value. apply_pruning: Whether to apply pruning. - ``bool``. Always apply it or not. - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. make_pruning_permanent: Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved. use_lottery_ticket_hypothesis: See `The lottery ticket hypothesis <https://arxiv.org/abs/1803.03635>`_: - ``bool``. Whether to apply it or not. - ``Callable[[epoch], bool]``. For dynamic values. Will be called every epoch. resample_parameters: Used with ``use_lottery_ticket_hypothesis``. If True, the model parameters will be resampled, otherwise, the exact original parameters will be used. pruning_dim: If you are using a structured pruning method you need to specify the dimension. pruning_norm: If you are using ``ln_structured`` you need to specify the norm. verbose: Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity prune_on_train_epoch_end: whether to apply pruning at the end of the training epoch. If this is ``False``, then the check runs at the end of the validation epoch. Raises: MisconfigurationException: If ``parameter_names`` is neither ``"weight"`` nor ``"bias"``, if the provided ``pruning_fn`` is not supported, if ``pruning_dim`` is not provided when ``"unstructured"``, if ``pruning_norm`` is not provided when ``"ln_structured"``, if ``pruning_fn`` is neither ``str`` nor :class:`torch.nn.utils.prune.BasePruningMethod`, or if ``amount`` is none of ``int``, ``float`` and ``Callable``. """self._use_global_unstructured=use_global_unstructuredself._parameters_to_prune=parameters_to_pruneself._use_lottery_ticket_hypothesis=use_lottery_ticket_hypothesisself._resample_parameters=resample_parametersself._prune_on_train_epoch_end=prune_on_train_epoch_endself._parameter_names=parameter_namesorself.PARAMETER_NAMESself._global_kwargs:Dict[str,Any]={}self._original_layers:Optional[Dict[int,_LayerRef]]=Noneself._pruning_method_name:Optional[str]=Nonefornameinself._parameter_names:ifnamenotinself.PARAMETER_NAMES:raiseMisconfigurationException(f"The provided `parameter_names` name: {name} isn't in {self.PARAMETER_NAMES}")ifisinstance(pruning_fn,str):pruning_kwargs={}pruning_fn=pruning_fn.lower()ifpruning_fnnotin_PYTORCH_PRUNING_FUNCTIONS:raiseMisconfigurationException(f"The provided `pruning_fn` {pruning_fn} isn't available in PyTorch's"f" built-in functions: {list(_PYTORCH_PRUNING_FUNCTIONS.keys())} ")ifpruning_fn.endswith("_structured"):ifpruning_dimisNone:raiseMisconfigurationException("When requesting `structured` pruning, the `pruning_dim` should be provided.")ifpruning_fn=="ln_structured":ifpruning_normisNone:raiseMisconfigurationException("When requesting `ln_structured` pruning, the `pruning_norm` should be provided.")pruning_kwargs["n"]=pruning_normpruning_kwargs["dim"]=pruning_dimpruning_fn=self._create_pruning_fn(pruning_fn,**pruning_kwargs)elifself._is_pruning_method(pruning_fn):ifnotuse_global_unstructured:raiseMisconfigurationException("PyTorch `BasePruningMethod` is currently only supported with `use_global_unstructured=True`.")else:raiseMisconfigurationException(f"`pruning_fn` is expected to be a str in {list(_PYTORCH_PRUNING_FUNCTIONS.keys())}"f" or a PyTorch `BasePruningMethod`. Found: {pruning_fn}."" HINT: if passing a `BasePruningMethod`, pass the the class, not an instance")# need to ignore typing here since pytorch base class does not define the PRUNING_TYPE attributeifuse_global_unstructuredandpruning_fn.PRUNING_TYPE!="unstructured":# type: ignoreraiseMisconfigurationException('Only the "unstructured" PRUNING_TYPE is supported with `use_global_unstructured=True`.'# type: ignoref" Found method {pruning_fn} of type {pruning_fn.PRUNING_TYPE}. ")self.pruning_fn=pruning_fnself._apply_pruning=apply_pruningself._make_pruning_permanent=make_pruning_permanentifnot(isinstance(amount,(int,float))orcallable(amount)):raiseMisconfigurationException("`amount` should be provided and be either an int, a float or Callable function.")self.amount=amountifverbosenotin(0,1,2):raiseMisconfigurationException("`verbose` must be any of (0, 1, 2)")self._verbose=verbose
[docs]deffilter_parameters_to_prune(self,parameters_to_prune:_PARAM_LIST=())->_PARAM_LIST:"""This function can be overridden to control which module to prune."""returnparameters_to_prune
def_create_pruning_fn(self,pruning_fn:str,**kwargs:Any)->Union[Callable,pytorch_prune.BasePruningMethod]:"""This function takes `pruning_fn`, a function name. IF use_global_unstructured, pruning_fn will be resolved into its associated ``PyTorch BasePruningMethod`` ELSE, pruning_fn will be resolved into its function counterpart from `torch.nn.utils.prune`. """pruning_meth=(_PYTORCH_PRUNING_METHOD[pruning_fn]ifself._use_global_unstructuredelse_PYTORCH_PRUNING_FUNCTIONS[pruning_fn])assertcallable(pruning_meth),"Selected pruning method is not callable"ifself._use_global_unstructured:self._global_kwargs=kwargs# save the function __name__ now because partial does not include it# and there are issues setting the attribute manually in ddp.self._pruning_method_name=pruning_meth.__name__ifself._use_global_unstructured:returnpruning_methreturnModelPruning._wrap_pruning_fn(pruning_meth,**kwargs)@staticmethoddef_wrap_pruning_fn(pruning_fn:Callable,**kwargs:Any)->Callable:returnpartial(pruning_fn,**kwargs)
[docs]defmake_pruning_permanent(self,module:nn.Module)->None:"""Removes pruning buffers from any pruned modules. Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180 """for_,moduleinmodule.named_modules():forkinlist(module._forward_pre_hooks):hook=module._forward_pre_hooks[k]ifisinstance(hook,pytorch_prune.BasePruningMethod):hook.remove(module)delmodule._forward_pre_hooks[k]
[docs]defapply_lottery_ticket_hypothesis(self)->None:r""" Lottery ticket hypothesis algorithm (see page 2 of the paper): 1. Randomly initialize a neural network :math:`f(x; \theta_0)` (where :math:`\theta_0 \sim \mathcal{D}_\theta`). 2. Train the network for :math:`j` iterations, arriving at parameters :math:`\theta_j`. 3. Prune :math:`p\%` of the parameters in :math:`\theta_j`, creating a mask :math:`m`. 4. Reset the remaining parameters to their values in :math:`\theta_0`, creating the winning ticket :math:`f(x; m \odot \theta_0)`. This function implements the step 4. The ``resample_parameters`` argument can be used to reset the parameters with a new :math:`\theta_z \sim \mathcal{D}_\theta` """# noqa: E501assertself._original_layersisnotNonefordinself._original_layers.values():copy=d["data"]names=d["names"]ifself._resample_parametersandhasattr(copy,"reset_parameters")andcallable(copy.reset_parameters):copy=deepcopy(copy)# keep the original parameterscopy.reset_parameters()fori,nameinnames:new,new_name=self._parameters_to_prune[i]self._copy_param(new,copy,name)
[docs]defapply_pruning(self,amount:Union[int,float])->None:"""Applies pruning to ``parameters_to_prune``."""ifself._verbose:prev_stats=[self._get_pruned_stats(m,n)form,ninself._parameters_to_prune]ifself._use_global_unstructured:self._apply_global_pruning(amount)else:self._apply_local_pruning(amount)ifself._verbose:curr_stats=[self._get_pruned_stats(m,n)form,ninself._parameters_to_prune]self._log_sparsity_stats(prev_stats,curr_stats,amount=amount)
@rank_zero_onlydef_log_sparsity_stats(self,prev:List[Tuple[int,int]],curr:List[Tuple[int,int]],amount:Union[int,float]=0)->None:total_params=sum(p.numel()forlayer,_inself._parameters_to_pruneforpinlayer.parameters())prev_total_zeros=sum(zerosforzeros,_inprev)curr_total_zeros=sum(zerosforzeros,_incurr)log.info(f"Applied `{self._pruning_method_name}`. Pruned:"f" {prev_total_zeros}/{total_params} ({prev_total_zeros/total_params:.2%}) ->"f" {curr_total_zeros}/{total_params} ({curr_total_zeros/total_params:.2%})")ifself._verbose==2:fori,(module,name)inenumerate(self._parameters_to_prune):prev_mask_zeros,prev_mask_size=prev[i]curr_mask_zeros,curr_mask_size=curr[i]log.info(f"Applied `{self._pruning_method_name}` to `{module!r}.{name}` with amount={amount}. Pruned:"f" {prev_mask_zeros} ({prev_mask_zeros/prev_mask_size:.2%}) ->"f" {curr_mask_zeros} ({curr_mask_zeros/curr_mask_size:.2%})")
[docs]defsetup(self,trainer:"pl.Trainer",pl_module:"pl.LightningModule",stage:Optional[str]=None)->None:parameters_to_prune=self.sanitize_parameters_to_prune(pl_module,self._parameters_to_prune,parameter_names=self._parameter_names)self._parameters_to_prune=self.filter_parameters_to_prune(parameters_to_prune)ifself._use_lottery_ticket_hypothesis:# group modules by id. Each entry has a copy of the initial data# and a list of the associated parameter names to pruneself._original_layers={}fori,(module,name)inenumerate(self._parameters_to_prune):id_=id(module)self._original_layers.setdefault(id_,_LayerRef(data=deepcopy(module),names=[]))self._original_layers[id_]["names"].append((i,name))
[docs]defon_train_end(self,trainer:"pl.Trainer",pl_module:LightningModule)->None:ifself._make_pruning_permanent:rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint")self.make_pruning_permanent(pl_module)
def_make_pruning_permanent_on_state_dict(self,pl_module:LightningModule)->Dict[str,Any]:state_dict=pl_module.state_dict()# find the mask and the original weights.map_pruned_params={k.replace("_mask","")forkinstate_dict.keys()ifk.endswith("_mask")}fortensor_nameinmap_pruned_params:orig=state_dict.pop(tensor_name+"_orig")mask=state_dict.pop(tensor_name+"_mask")# make weights permanentstate_dict[tensor_name]=mask.to(dtype=orig.dtype)*origdefmove_to_cpu(tensor:torch.Tensor)->torch.Tensor:# each tensor and move them on cpureturntensor.cpu()returnapply_to_collection(state_dict,torch.Tensor,move_to_cpu)
[docs]defon_save_checkpoint(self,trainer:"pl.Trainer",pl_module:"pl.LightningModule",checkpoint:Dict[str,Any])->Optional[dict]:ifself._make_pruning_permanent:rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint")# manually prune the weights so training can keep going with the same bufferscheckpoint["state_dict"]=self._make_pruning_permanent_on_state_dict(pl_module)
[docs]@staticmethoddefsanitize_parameters_to_prune(pl_module:LightningModule,parameters_to_prune:_PARAM_LIST=(),parameter_names:Sequence[str]=())->_PARAM_LIST:"""This function is responsible of sanitizing ``parameters_to_prune`` and ``parameter_names``. If ``parameters_to_prune is None``, it will be generated with all parameters of the model. Raises: MisconfigurationException: If ``parameters_to_prune`` doesn't exist in the model, or if ``parameters_to_prune`` is neither a list nor a tuple. """parameters=parameter_namesorModelPruning.PARAMETER_NAMEScurrent_modules=[mforminpl_module.modules()ifnotisinstance(m,_MODULE_CONTAINERS)]ifnotparameters_to_prune:parameters_to_prune=[(m,p)forpinparametersformincurrent_modulesifgetattr(m,p,None)isnotNone]elif(isinstance(parameters_to_prune,(list,tuple))andlen(parameters_to_prune)>0andall(len(p)==2forpinparameters_to_prune)andall(isinstance(a,nn.Module)andisinstance(b,str)fora,binparameters_to_prune)):missing_modules,missing_parameters=[],[]formodule,nameinparameters_to_prune:ifmodulenotincurrent_modules:missing_modules.append(module)continueifnothasattr(module,name):missing_parameters.append(name)ifmissing_modulesormissing_parameters:raiseMisconfigurationException("Some provided `parameters_to_tune` don't exist in the model."f" Found missing modules: {missing_modules} and missing parameters: {missing_parameters}")else:raiseMisconfigurationException("The provided `parameters_to_prune` should either be list of tuple"" with 2 elements: (nn.Module, parameter_name_to_prune) or None")returnparameters_to_prune
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.