
ModelPruning

class pytorch_lightning.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)

Bases: pytorch_lightning.callbacks.callback.Callback

Model pruning callback, using PyTorch’s prune utilities. This callback is responsible for pruning network parameters during training.

To learn more about pruning with PyTorch, please take a look at the PyTorch pruning tutorial.

Warning

ModelPruning is in beta and subject to change.

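from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelPruning

# `model` is your LightningModule; here it is assumed to have `mlp_1` and `mlp_2` submodules.
parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=parameters_to_prune,
            amount=0.01,  # prune 1% of the selected weights
            use_global_unstructured=True,
        )
    ]
)
trainer.fit(model)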

When parameters_to_prune is not provided (empty or None), it is populated with all parameters of the model named in parameter_names (by default "weight" and "bias"). You can override filter_parameters_to_prune to exclude any nn.Module from pruning.

Parameters
  • pruning_fn (Union[Callable, str]) – A function from the torch.nn.utils.prune module, or your own subclass of PyTorch’s BasePruningMethod. Can also be a string, e.g. "l1_unstructured". See the PyTorch documentation for more details.

  • parameters_to_prune (Sequence[Tuple[Module, str]]) – List of tuples (nn.Module, "parameter_name_string").

  • parameter_names (Optional[List[str]]) – List of parameter names to be pruned from the nn.Module. Each name can be either "weight" or "bias".

  • use_global_unstructured (bool) – Whether to apply pruning globally on the model. If parameters_to_prune is provided, global unstructured pruning is restricted to those parameters.

  • amount (Union[int, float, Callable[[int], Union[int, float]]]) –

    Quantity of parameters to prune:

    • float. Between 0.0 and 1.0. Represents the fraction of parameters to prune.

    • int. Represents the absolute number of parameters to prune.

    • Callable. For dynamic values. Called every epoch with the current epoch; should return an int or float as described above.

  • apply_pruning (Union[bool, Callable[[int], bool]]) –

    Whether to apply pruning.

    • bool. Always apply it or not.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • make_pruning_permanent (bool) – Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved.

  • use_lottery_ticket_hypothesis (Union[bool, Callable[[int], bool]]) –

    See the paper "The Lottery Ticket Hypothesis" (Frankle and Carbin, 2019):

    • bool. Whether to apply it or not.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • resample_parameters (bool) – Used with use_lottery_ticket_hypothesis. If True, the model parameters are resampled; otherwise, the exact original parameters are used.

  • pruning_dim (Optional[int]) – If you are using a structured pruning method, you need to specify the dimension.

  • pruning_norm (Optional[int]) – If you are using ln_structured you need to specify the norm.

  • verbose (int) – Verbosity level: 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.

  • prune_on_train_epoch_end (bool) – Whether to apply pruning at the end of the training epoch. If False, pruning is applied at the end of the validation epoch instead.
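The Callable forms of amount, apply_pruning and use_lottery_ticket_hypothesis each take the current epoch as their only argument. A minimal sketch of such schedules follows; the function names and the particular thresholds are illustrative assumptions, not defaults of the callback.

```python
def apply_pruning_schedule(epoch: int) -> bool:
    # Hypothetical schedule: skip a 5-epoch warm-up, then prune every 5th epoch.
    return epoch >= 5 and epoch % 5 == 0


def amount_schedule(epoch: int) -> float:
    # Hypothetical schedule: prune 20% per pruning step early on, then taper to 5%.
    return 0.2 if epoch < 20 else 0.05


def lottery_ticket_schedule(epoch: int) -> bool:
    # Hypothetical schedule: rewind surviving weights to their initial values only before epoch 20.
    return epoch < 20


# Show which epochs would trigger pruning, and with what settings.
for epoch in range(0, 30, 5):
    print(epoch, apply_pruning_schedule(epoch), amount_schedule(epoch), lottery_ticket_schedule(epoch))
```

Under these assumptions, the schedules would be passed as ModelPruning("l1_unstructured", amount=amount_schedule, apply_pruning=apply_pruning_schedule, use_lottery_ticket_hypothesis=lottery_ticket_schedule).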

Raises

MisconfigurationException – If parameter_names contains a name other than "weight" or "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided for a structured pruning method, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither a str nor a torch.nn.utils.prune.BasePruningMethod, or if amount is not an int, float, or Callable.
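For structured methods, pruning_dim and pruning_norm appear to correspond to the dim and n arguments of torch.nn.utils.prune.ln_structured (this mapping is an assumption based on the parameter descriptions above). The torch-only sketch below, using an illustrative nn.Linear layer, shows why a dimension is required: structured pruning zeroes entire slices along that dimension rather than individual weights.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(6, 4)  # weight shape: (out_features=4, in_features=6)

# L2-norm structured pruning along dim=0: the 50% of output rows with the
# smallest L2 norm are zeroed as whole rows. In this sketch, `dim` plays the
# role of pruning_dim and `n` the role of pruning_norm.
prune.ln_structured(layer, name="weight", amount=0.5, n=2, dim=0)

row_is_zero = (layer.weight == 0).all(dim=1)
print(row_is_zero)  # two of the four rows are entirely zero
```

The analogous callback configuration would presumably be ModelPruning("ln_structured", pruning_dim=0, pruning_norm=2, amount=0.5, use_global_unstructured=False).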

apply_lottery_ticket_hypothesis()

Lottery ticket hypothesis algorithm (see page 2 of the paper):

  1. Randomly initialize a neural network $f(x; \theta_0)$ (where $\theta_0 \sim \mathcal{D}_\theta$).

  2. Train the network for $j$ iterations, arriving at parameters $\theta_j$.

  3. Prune $p\%$ of the parameters in $\theta_j$, creating a mask $m$.

  4. Reset the remaining parameters to their values in $\theta_0$, creating the winning ticket $f(x; m \odot \theta_0)$.

This function implements step 4.

The resample_parameters argument can be used to reset the parameters to a new $\theta_z \sim \mathcal{D}_\theta$ instead of $\theta_0$.

Return type

None
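Step 4 can be illustrated with plain torch.nn.utils.prune on a single layer. This is a minimal sketch under stated assumptions, not the callback's own code: the layer is illustrative, "training" is simulated by perturbing the weights, and L1 magnitude pruning stands in for whatever pruning_fn is configured.

```python
import copy

import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)

# Step 1: randomly initialize the network and remember theta_0.
layer = nn.Linear(8, 4)
theta_0 = copy.deepcopy(layer.state_dict())

# Step 2: stand-in for training; here the parameters are simply perturbed.
with torch.no_grad():
    layer.weight.add_(torch.randn_like(layer.weight))
    layer.bias.add_(torch.randn_like(layer.bias))

# Step 3: prune 50% of the trained weights by L1 magnitude, creating the mask m
# (torch stores it as the buffer `weight_mask`).
prune.l1_unstructured(layer, name="weight", amount=0.5)

# Step 4: rewind the surviving parameters to theta_0. After pruning, the dense
# values live in the parameter `weight_orig`, and `weight` = weight_orig * weight_mask.
with torch.no_grad():
    layer.weight_orig.copy_(theta_0["weight"])
    layer.bias.copy_(theta_0["bias"])

# The forward pre-hook installed by torch's pruning recomputes `weight` on the next call.
_ = layer(torch.zeros(1, 8))
winning_ticket = layer.weight  # elementwise mask * theta_0 for the weight tensor
```

For the resample_parameters=True variant, theta_0 would be replaced by a fresh draw, for instance from a copy of the layer whose reset_parameters() method has been called.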

apply_pruning(amount)

Applies pruning to parameters_to_prune.

Return type

None
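With use_global_unstructured=True and pruning_fn="l1_unstructured", applying pruning roughly corresponds to torch.nn.utils.prune.global_unstructured with the L1Unstructured method. The torch-only sketch below is not the callback's internal code; the model layout and the sparsity helper are illustrative assumptions. It also computes per-layer and overall sparsity, the quantities the verbose levels refer to.

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))
parameters_to_prune = [(model[0], "weight"), (model[2], "weight")]

# Global unstructured pruning: the 50% smallest-magnitude weights are selected
# across all listed tensors together, so individual layers may end up more or
# less sparse than 50%.
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.5,
)


def sparsity(tensor: torch.Tensor) -> float:
    """Fraction of exactly-zero entries in a tensor (hypothetical helper)."""
    return float((tensor == 0).sum()) / tensor.numel()


# Per-layer sparsity (what verbose=2 is described as logging).
for module, name in parameters_to_prune:
    print(f"{module}: {sparsity(getattr(module, name)):.2%} zeros")

# Overall sparsity across all pruned tensors (what verbose=1 is described as logging).
total_zeros = sum(int((getattr(m, n) == 0).sum()) for m, n in parameters_to_prune)
total = sum(getattr(m, n).numel() for m, n in parameters_to_prune)
print(f"global sparsity: {total_zeros / total:.2%}")
```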

filter_parameters_to_prune(parameters_to_prune=())

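Override this method to control which modules are pruned. It receives parameters_to_prune and should return the (module, parameter name) pairs that should be pruned.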

Return type

Sequence[Tuple[Module, str]]
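The filtering logic for an override can be written as a plain function over (module, name) pairs. A minimal sketch, assuming you only want to prune the weights of nn.Linear layers; the function name keep_linear_weights and the model are hypothetical.

```python
from typing import List, Sequence, Tuple

from torch import nn

ParamList = Sequence[Tuple[nn.Module, str]]


def keep_linear_weights(parameters_to_prune: ParamList) -> List[Tuple[nn.Module, str]]:
    """Hypothetical filter: keep only the "weight" tensors of nn.Linear layers."""
    return [
        (module, name)
        for module, name in parameters_to_prune
        if isinstance(module, nn.Linear) and name == "weight"
    ]


# Illustrative model and candidate list.
model = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.Flatten(), nn.Linear(8, 2))
candidates = [
    (model[0], "weight"),
    (model[0], "bias"),
    (model[2], "weight"),
    (model[2], "bias"),
]
filtered = keep_linear_weights(candidates)
print(filtered)  # only the (Linear, "weight") pair remains
```

Inside a ModelPruning subclass, an override of filter_parameters_to_prune(self, parameters_to_prune=()) could simply return keep_linear_weights(parameters_to_prune).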

make_pruning_permanent(module)

Removes pruning buffers from any pruned modules.

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122

Return type

None
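The underlying torch operation for removing pruning buffers is torch.nn.utils.prune.remove. A minimal torch-only sketch on a single illustrative layer, showing what the module looks like while pruned and after the pruning is made permanent:

```python
import torch
from torch import nn
from torch.nn.utils import prune

torch.manual_seed(0)
layer = nn.Linear(4, 3)
prune.random_unstructured(layer, name="weight", amount=0.5)

# While pruned, the layer carries a reparametrization: `weight_orig` (parameter),
# `weight_mask` (buffer) and a forward pre-hook that recomputes `weight`.
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight_orig']
print(sorted(n for n, _ in layer.named_buffers()))  # ['weight_mask']

# Make the pruning permanent: bake the mask into `weight` and drop the extras.
prune.remove(layer, "weight")
print(sorted(n for n, _ in layer.named_parameters()))  # ['bias', 'weight']
print(sorted(n for n, _ in layer.named_buffers()))  # []
```

The zeros introduced by pruning remain in the plain weight parameter; only the mask and the hook are gone.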

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

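Parameters
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the checkpoint dictionary that will be saved.

Return type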

None

on_train_end(trainer, pl_module)

Called when training ends.

Return type

None

on_train_epoch_end(trainer, pl_module)

Called when the training epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module, or

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.

Return type

None

on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type

None

static sanitize_parameters_to_prune(pl_module, parameters_to_prune=(), parameter_names=())

This function is responsible for sanitizing parameters_to_prune and parameter_names. If parameters_to_prune is empty or None, it is generated from all parameters of the model.

Raises

MisconfigurationException – If parameters_to_prune doesn’t exist in the model, or if parameters_to_prune is neither a list nor a tuple.

Return type

Sequence[Tuple[Module, str]]
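The docstring does not spell out exactly what "all parameters of the model" covers. A torch-only sketch, assuming it means every tensor named in parameter_names (defaulting here to "weight" and "bias") on each leaf module; the helper name default_parameters_to_prune is hypothetical, and the callback's own handling of container modules may differ.

```python
from typing import List, Sequence, Tuple

from torch import nn


def default_parameters_to_prune(
    model: nn.Module, parameter_names: Sequence[str] = ("weight", "bias")
) -> List[Tuple[nn.Module, str]]:
    """Hypothetical helper: pair every leaf module with each named tensor it actually has."""
    # Leaf modules are those without children, which skips containers such as nn.Sequential.
    leaf_modules = [m for m in model.modules() if next(m.children(), None) is None]
    return [
        (module, name)
        for module in leaf_modules
        for name in parameter_names
        # getattr returns None when the attribute is missing or registered as None (e.g. bias=False).
        if getattr(module, name, None) is not None
    ]


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2, bias=False))
pairs = default_parameters_to_prune(model)
print([(type(m).__name__, name) for m, name in pairs])
# [('Linear', 'weight'), ('Linear', 'bias'), ('Linear', 'weight')]
```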

setup(trainer, pl_module, stage)

Called when fit, validate, test, predict, or tune begins.

Return type

None