BaseFinetuning

class pytorch_lightning.callbacks.BaseFinetuning[source]

Bases: pytorch_lightning.callbacks.callback.Callback

This class implements the base logic for writing your own finetuning callback.

Override the freeze_before_training and finetune_function methods with your own logic.

freeze_before_training: This method is called before configure_optimizers and should be used to freeze any module's parameters.

finetune_function: This method is called at the start of every training epoch and should be used to unfreeze any parameters. Those parameters need to be added to the optimizer in a new param_group.

Note

Make sure to filter the parameters based on requires_grad.

Example:

>>> import pytorch_lightning as pl
>>> from pytorch_lightning.callbacks import BaseFinetuning
>>> from torch.optim import Adam
>>> class MyModel(pl.LightningModule):
...     def configure_optimizers(self):
...         # Make sure to filter the parameters based on `requires_grad`
...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
...
>>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
...     def __init__(self, unfreeze_at_epoch=10):
...         super().__init__()
...         self._unfreeze_at_epoch = unfreeze_at_epoch
...
...     def freeze_before_training(self, pl_module):
...         # freeze any module you want
...         # Here, we are freezing `feature_extractor`
...         self.freeze(pl_module.feature_extractor)
...
...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
...         # When `current_epoch` reaches `unfreeze_at_epoch` (10 by default), feature_extractor will start training.
...         if current_epoch == self._unfreeze_at_epoch:
...             self.unfreeze_and_add_param_group(
...                 modules=pl_module.feature_extractor,
...                 optimizer=optimizer,
...                 train_bn=True,
...             )
static filter_on_optimizer(optimizer, params)[source]

This function is used to exclude any parameter which already exists in this optimizer.

Parameters:
  • optimizer (Optimizer) – Optimizer used for parameter exclusion

  • params (Iterable) – Iterable of parameters used to check against the provided optimizer

Return type:

List

Returns:

List of parameters not contained in this optimizer's param groups

static filter_params(modules, train_bn=True, requires_grad=True)[source]

Yields the requires_grad parameters of a given module or list of modules.

Parameters:
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

  • train_bn (bool) – Whether to train BatchNorm modules

  • requires_grad (bool) – Whether to create a generator for trainable or non-trainable parameters.

Return type:

Generator

Returns:

Generator

finetune_function(pl_module, epoch, optimizer, opt_idx)[source]

Override to add your unfreeze logic.

Return type:

None

static flatten_modules(modules)[source]

This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type:

List[Module]

Returns:

List of modules

static freeze(modules, train_bn=True)[source]

Freezes the parameters of the provided modules.

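Parameters:
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

  • train_bn (bool) – If True, leave the BatchNorm layers in training mode
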
Return type:

None

Returns:

None

freeze_before_training(pl_module)[source]

Override to add your freeze logic.

Return type:

None

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload callback state given the callback’s state_dict.

Parameters:

state_dict (Dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

static make_trainable(modules)[source]

Unfreezes the parameters of the provided modules.

Parameters:

modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A given module or an iterable of modules

Return type:

None

on_fit_start(trainer, pl_module)[source]

Called when fit begins.

Return type:

None

on_train_epoch_start(trainer, pl_module)[source]

Called when the epoch begins.

Return type:

None

setup(trainer, pl_module, stage=None)[source]

Called when fit, validate, test, predict, or tune begins.

Return type:

None

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback’s state_dict.

Return type:

Dict[str, Any]

Returns:

A dictionary containing callback state.

static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)[source]

Unfreezes a module and adds its parameters to an optimizer.

Parameters:
  • modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to an optimizer as a new param group.

  • optimizer (Optimizer) – The optimizer that receives the new parameters, added as a new param group through add_param_group.

  • lr (Optional[float]) – Learning rate for the new param group.

  • initial_denom_lr (float) – If no lr is provided, the learning rate from the first param group will be used and divided by initial_denom_lr.

  • train_bn (bool) – Whether to train the BatchNormalization layers.

Return type:

None