ModelPruning

class lightning.pytorch.callbacks.ModelPruning(pruning_fn, parameters_to_prune=(), parameter_names=None, use_global_unstructured=True, amount=0.5, apply_pruning=True, make_pruning_permanent=True, use_lottery_ticket_hypothesis=True, resample_parameters=False, pruning_dim=None, pruning_norm=None, verbose=0, prune_on_train_epoch_end=True)[source]

Bases: Callback

Model pruning Callback, using PyTorch’s prune utilities. This callback is responsible for pruning a network’s parameters during training.

To learn more about pruning with PyTorch, please take a look at this tutorial.

Warning

This is an experimental feature.

from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import ModelPruning

parameters_to_prune = [(model.mlp_1, "weight"), (model.mlp_2, "weight")]

trainer = Trainer(
    callbacks=[
        ModelPruning(
            pruning_fn="l1_unstructured",
            parameters_to_prune=parameters_to_prune,
            amount=0.01,
            use_global_unstructured=True,
        )
    ]
)

When parameters_to_prune is None, parameters_to_prune will contain all parameters from the model. The user can override filter_parameters_to_prune to control which nn.Module parameters are pruned, as shown in the sketch below.
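A minimal sketch of such an override, assuming a hypothetical LinearOnlyPruning subclass that restricts pruning to nn.Linear weights:

from torch import nn

from lightning.pytorch.callbacks import ModelPruning


class LinearOnlyPruning(ModelPruning):  # hypothetical subclass for illustration
    def filter_parameters_to_prune(self, parameters_to_prune=()):
        # keep only (module, parameter_name) pairs whose module is a Linear layer
        return [(m, n) for m, n in parameters_to_prune if isinstance(m, nn.Linear)]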

Parameters:
  • pruning_fn (Union[Callable, str]) – Function from the torch.nn.utils.prune module or your own PyTorch BasePruningMethod subclass. Can also be a string, e.g. "l1_unstructured". See the PyTorch docs for more details.

  • parameters_to_prune (Sequence[tuple[Module, str]]) – List of tuples (nn.Module, "parameter_name_string").

  • parameter_names (Optional[list[str]]) – List of parameter names to be pruned from the nn.Module. Can either be "weight" or "bias".

  • use_global_unstructured (bool) – Whether to apply pruning globally on the model. If parameters_to_prune is provided, global unstructured pruning will be restricted to them.

  • amount (Union[int, float, Callable[[int], Union[int, float]]]) –

    Quantity of parameters to prune:

    • float. Between 0.0 and 1.0. Represents the fraction of parameters to prune.

    • int. Represents the absolute number of parameters to prune.

    • Callable. For dynamic values. Will be called every epoch. Should return an int or a float (see the sketch after this parameter list).

  • apply_pruning (Union[bool, Callable[[int], bool]]) –

    Whether to apply pruning.

    • bool. Always apply it or not.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • make_pruning_permanent (bool) – Whether to remove all reparametrization pre-hooks and apply masks when training ends or the model is saved.

  • use_lottery_ticket_hypothesis (Union[bool, Callable[[int], bool]]) –

    See The lottery ticket hypothesis:

    • bool. Whether to apply it or not.

    • Callable[[epoch], bool]. For dynamic values. Will be called every epoch.

  • resample_parameters (bool) – Used with use_lottery_ticket_hypothesis. If True, the model parameters will be resampled, otherwise, the exact original parameters will be used.

  • pruning_dim (Optional[int]) – If you are using a structured pruning method you need to specify the dimension.

  • pruning_norm (Optional[int]) – If you are using ln_structured you need to specify the norm.

  • verbose (int) – Verbosity level. 0 to disable, 1 to log overall sparsity, 2 to log per-layer sparsity.

  • prune_on_train_epoch_end (bool) – Whether to apply pruning at the end of the training epoch. If this is False, pruning runs at the end of the validation epoch instead.
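Since amount and apply_pruning both accept per-epoch callables, pruning schedules can be expressed directly. A minimal sketch (the prune_amount schedule and its thresholds are illustrative assumptions, not part of the API):

from lightning.pytorch.callbacks import ModelPruning


def prune_amount(epoch):
    # illustrative schedule: prune gently early on, more aggressively later
    return 0.05 if epoch < 10 else 0.25


pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=prune_amount,                     # called every epoch, returns a float
    apply_pruning=lambda epoch: epoch >= 5,  # skip pruning for the first 5 epochs
)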

Raises:

MisconfigurationException – If parameter_names is neither "weight" nor "bias", if the provided pruning_fn is not supported, if pruning_dim is not provided when using a structured pruning method, if pruning_norm is not provided when using "ln_structured", if pruning_fn is neither a str nor a torch.nn.utils.prune.BasePruningMethod, or if amount is none of int, float, or Callable.

apply_lottery_ticket_hypothesis()[source]

Lottery ticket hypothesis algorithm (see page 2 of the paper):

  1. Randomly initialize a neural network \(f(x; \theta_0)\) (where \(\theta_0 \sim \mathcal{D}_\theta\)).

  2. Train the network for \(j\) iterations, arriving at parameters \(\theta_j\).

  3. Prune \(p\%\) of the parameters in \(\theta_j\), creating a mask \(m\).

  4. Reset the remaining parameters to their values in \(\theta_0\), creating the winning ticket \(f(x; m \odot \theta_0)\).

This function implements step 4.

The resample_parameters argument can be used to reset the parameters with a newly sampled \(\theta_z \sim \mathcal{D}_\theta\).

Return type:

None
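A minimal sketch of enabling this behavior through the callback (the parameter values are illustrative):

from lightning.pytorch.callbacks import ModelPruning

pruning = ModelPruning(
    pruning_fn="l1_unstructured",
    amount=0.2,
    use_lottery_ticket_hypothesis=True,  # reset surviving weights to their values in theta_0
    resample_parameters=False,           # keep the exact original parameters, no resampling
)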

apply_pruning(amount)[source]

Applies pruning to parameters_to_prune.

Return type:

None

filter_parameters_to_prune(parameters_to_prune=())[source]

This function can be overridden to control which modules to prune.

Return type:

Sequence[tuple[Module, str]]

make_pruning_permanent(module)[source]

Removes pruning buffers from any pruned modules.

Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/nn/utils/prune.py#L1118-L1122

Return type:

None
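As background, a minimal sketch of the torch.nn.utils.prune mechanics this method builds on: pruning reparametrizes a parameter name into name_orig plus a name_mask buffer, and prune.remove folds the mask back into a plain tensor (the nn.Linear layer below is only for illustration):

from torch import nn
from torch.nn.utils import prune

layer = nn.Linear(4, 4)
prune.l1_unstructured(layer, name="weight", amount=0.5)
# the layer now carries the reparametrization: weight_orig and weight_mask
assert hasattr(layer, "weight_orig") and hasattr(layer, "weight_mask")

prune.remove(layer, "weight")  # make the pruning permanent
assert not hasattr(layer, "weight_orig")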

on_save_checkpoint(trainer, pl_module, checkpoint)[source]

Called when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

  • trainer (Trainer) – the current Trainer instance.

  • pl_module (LightningModule) – the current LightningModule instance.

  • checkpoint (dict[str, Any]) – the checkpoint dictionary that will be saved.

Return type:

None

on_train_end(trainer, pl_module)[source]

Called when the train ends.

Return type:

None

on_train_epoch_end(trainer, pl_module)[source]

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the lightning.pytorch.core.LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()

Return type:

None

on_validation_epoch_end(trainer, pl_module)[source]

Called when the validation epoch ends.

Return type:

None

static sanitize_parameters_to_prune(pl_module, parameters_to_prune=(), parameter_names=())[source]

This function is responsible for sanitizing parameters_to_prune and parameter_names. If parameters_to_prune is None, it will be generated with all parameters of the model.

Raises:

MisconfigurationException – If parameters_to_prune doesn’t exist in the model, or if parameters_to_prune is neither a list nor a tuple.

Return type:

Sequence[tuple[Module, str]]

setup(trainer, pl_module, stage)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None