BaseFinetuning
- class pytorch_lightning.callbacks.BaseFinetuning
Bases: pytorch_lightning.callbacks.base.Callback
This class implements the base logic for writing your own Finetuning Callback.
Override the freeze_before_training and finetune_function methods with your own logic.
freeze_before_training: This method is called before configure_optimizers and should be used to freeze the parameters of any modules.
finetune_function: This method is called at the start of every training epoch and should be used to unfreeze any parameters. Those parameters need to be added to a new param_group within the optimizer.
Note
Make sure to filter the parameters based on requires_grad.
Example:
>>> import pytorch_lightning as pl
>>> from pytorch_lightning.callbacks import BaseFinetuning
>>> from torch.optim import Adam
>>> class MyModel(pl.LightningModule):
...     def configure_optimizers(self):
...         # Make sure to filter the parameters based on `requires_grad`
...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
...
>>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
...     def __init__(self, unfreeze_at_epoch=10):
...         super().__init__()
...         self._unfreeze_at_epoch = unfreeze_at_epoch
...
...     def freeze_before_training(self, pl_module):
...         # Freeze any module you want.
...         # Here, we are freezing `feature_extractor`.
...         self.freeze(pl_module.feature_extractor)
...
...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
...         # When `current_epoch` reaches `unfreeze_at_epoch` (10 by default),
...         # `feature_extractor` starts training.
...         if current_epoch == self._unfreeze_at_epoch:
...             self.unfreeze_and_add_param_group(
...                 modules=pl_module.feature_extractor,
...                 optimizer=optimizer,
...                 train_bn=True,
...             )
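The reason for filtering on requires_grad can be seen with plain PyTorch, without Lightning. The sketch below is illustrative only: the two-layer model and the submodule names `backbone` and `head` are hypothetical. It freezes one submodule by hand, builds the optimizer from trainable parameters only, and later adds the unfrozen submodule as a separate param group.

```python
import torch.nn as nn
from torch.optim import Adam

# Hypothetical model: a `backbone` to freeze and a `head` that keeps training.
model = nn.Sequential()
model.add_module("backbone", nn.Linear(4, 8))
model.add_module("head", nn.Linear(8, 2))

# Freezing a module means switching off gradients for its parameters.
for p in model.backbone.parameters():
    p.requires_grad = False

# Only trainable tensors are handed to the optimizer (the pattern from the Note).
optimizer = Adam(filter(lambda p: p.requires_grad, model.parameters()))
n_in_optimizer = sum(len(g["params"]) for g in optimizer.param_groups)  # head weight + bias

# At unfreeze time the backbone joins as its own param group. Had its parameters
# been passed to Adam above, add_param_group would reject them with a ValueError
# because a tensor may not appear in more than one group.
for p in model.backbone.parameters():
    p.requires_grad = True
optimizer.add_param_group({"params": list(model.backbone.parameters()), "lr": 1e-4})
```

Without the filter, the frozen backbone tensors would already sit in the first group, and the later add_param_group call would fail.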
- static filter_on_optimizer(optimizer, params)
This function is used to exclude any parameter that already exists in the given optimizer.
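A minimal standalone sketch of the same idea in plain PyTorch, not Lightning's implementation: it compares tensors by identity (`id`), which is an assumption of this sketch rather than documented library behavior. Filtering first matters because `torch.optim.Optimizer.add_param_group` raises a ValueError when a parameter already belongs to another group.

```python
from typing import Iterable, List

import torch
import torch.nn as nn


def filter_on_optimizer(optimizer: torch.optim.Optimizer,
                        params: Iterable[nn.Parameter]) -> List[nn.Parameter]:
    """Return only the params not yet registered in any param group of `optimizer`."""
    # Identities of every tensor the optimizer already tracks, across all groups.
    existing = {id(p) for group in optimizer.param_groups for p in group["params"]}
    return [p for p in params if id(p) not in existing]


# Usage: `a` is already optimized, so only `b`'s parameters survive the filter.
a, b = nn.Linear(2, 2), nn.Linear(2, 2)
optimizer = torch.optim.SGD(a.parameters(), lr=0.1)
new_params = filter_on_optimizer(optimizer, list(a.parameters()) + list(b.parameters()))
optimizer.add_param_group({"params": new_params})  # no duplicate-parameter error
```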
- static filter_params(modules, train_bn=True, requires_grad=True)
Yields the requires_grad parameters of a given module or list of modules.
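- Parameters
modules – A module or an iterable of modules.
train_bn – Whether to train BatchNorm layers.
requires_grad – Whether to yield trainable or non-trainable parameters.
- Return type
Generator
- Returns
Generator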
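A rough plain-PyTorch sketch of the behavior described above, not the library's code. Instead of a separate flattening step, it walks `Module.modules()` and relies on `parameters(recurse=False)` so each tensor is yielded once, by the module that owns it. The helper `_iter_modules` and the example network are hypothetical; overlapping modules inside an iterable would yield duplicates.

```python
from typing import Iterable, Iterator, Union

import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def _iter_modules(modules: Union[nn.Module, Iterable]) -> Iterator[nn.Module]:
    # A single module yields itself and all descendants; iterables are walked recursively.
    if isinstance(modules, nn.Module):
        yield from modules.modules()
    else:
        for item in modules:
            yield from _iter_modules(item)


def filter_params(modules: Union[nn.Module, Iterable], train_bn: bool = True,
                  requires_grad: bool = True) -> Iterator[nn.Parameter]:
    for mod in _iter_modules(modules):
        if isinstance(mod, _BatchNorm) and not train_bn:
            continue  # BatchNorm layers contribute nothing when train_bn=False
        # recurse=False: each parameter is yielded only by the module that owns it.
        for param in mod.parameters(recurse=False):
            if param.requires_grad == requires_grad:
                yield param


# Usage on a small hypothetical network with one frozen weight.
net = nn.Sequential(nn.Linear(3, 4), nn.BatchNorm1d(4), nn.Linear(4, 1))
net[2].weight.requires_grad = False

trainable = list(filter_params(net))                    # 5 tensors
without_bn = list(filter_params(net, train_bn=False))   # 3 tensors
frozen = list(filter_params(net, requires_grad=False))  # only net[2].weight
```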
- finetune_function(pl_module, epoch, optimizer, opt_idx)
Override to add your unfreeze logic.
- Return type
None
- static flatten_modules(modules)
This function is used to flatten a module or an iterable of modules into a list of its leaf modules (modules with no children) and parent modules that have parameters directly themselves.
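The flattening rule can be sketched in a few lines of plain PyTorch. This is an illustration of the description above, not Lightning's source; the `Scaled` module and the example network are hypothetical. Leaves are kept even if they have no parameters (such as `ReLU`), while pure containers such as `nn.Sequential` are dropped.

```python
from typing import Iterable, List, Union

import torch
import torch.nn as nn


def flatten_modules(modules: Union[nn.Module, Iterable]) -> List[nn.Module]:
    if isinstance(modules, nn.Module):
        candidates = list(modules.modules())  # the module itself, then all descendants
    else:
        candidates = [m for item in modules for m in flatten_modules(item)]
    # Keep leaves, plus parents that hold parameters of their own.
    return [
        m for m in candidates
        if not list(m.children()) or list(m.parameters(recurse=False))
    ]


class Scaled(nn.Module):
    """Hypothetical parent module that owns a parameter besides its child."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))
        self.child = nn.Linear(3, 3)


net = nn.Sequential(nn.Linear(2, 3), nn.BatchNorm1d(3), nn.Sequential(nn.ReLU(), nn.Linear(3, 1)))
scaled = Scaled()

print([type(m).__name__ for m in flatten_modules(net)])
# ['Linear', 'BatchNorm1d', 'ReLU', 'Linear']
print([type(m).__name__ for m in flatten_modules([net, scaled])])
# ['Linear', 'BatchNorm1d', 'ReLU', 'Linear', 'Scaled', 'Linear']
```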
- on_before_accelerator_backend_setup(trainer, pl_module)
Called before the accelerator is set up.
- on_load_checkpoint(trainer, pl_module, callback_state)
Called when loading a model checkpoint; use it to reload state.
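- Parameters
trainer – The current Trainer instance.
pl_module – The current LightningModule instance.
callback_state – The callback state returned by on_save_checkpoint.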
Note
The on_load_checkpoint hook won't be called with an undefined state. If your on_load_checkpoint hook behavior doesn't rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
- Return type
None
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a model checkpoint; use it to persist state.
- static unfreeze_and_add_param_group(modules, optimizer, lr=None, initial_denom_lr=10.0, train_bn=True)
Unfreezes a module and adds its parameters to an optimizer.
- Parameters
modules (Union[Module, Iterable[Union[Module, Iterable]]]) – A module or iterable of modules to unfreeze. Their parameters will be added to the optimizer as a new param group.
optimizer (Optimizer) – The optimizer that receives the new parameters through add_param_group.
lr (Optional[float]) – Learning rate for the new param group.
initial_denom_lr (float) – If no lr is provided, the learning rate of the first param group is divided by initial_denom_lr and used instead.
train_bn (bool) – Whether to train the BatchNorm layers.
- Return type
None
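A minimal plain-PyTorch sketch of the documented behavior, not Lightning's implementation: it unfreezes the given modules, skips BatchNorm layers when train_bn is False, drops tensors the optimizer already tracks, and adds the rest as one new group whose learning rate falls back to the first group's lr divided by initial_denom_lr. Unlike the documented signature, it assumes modules is a single module or a flat iterable of modules; the backbone/head setup is hypothetical.

```python
from typing import Iterable, Optional, Union

import torch
import torch.nn as nn
from torch.nn.modules.batchnorm import _BatchNorm


def unfreeze_and_add_param_group(
    modules: Union[nn.Module, Iterable[nn.Module]],
    optimizer: torch.optim.Optimizer,
    lr: Optional[float] = None,
    initial_denom_lr: float = 10.0,
    train_bn: bool = True,
) -> None:
    if isinstance(modules, nn.Module):
        modules = [modules]
    # Tensors the optimizer already tracks, so nothing is registered twice.
    existing = {id(p) for g in optimizer.param_groups for p in g["params"]}
    new_params = []
    for root in modules:
        for mod in root.modules():
            if isinstance(mod, _BatchNorm) and not train_bn:
                continue  # leave BatchNorm layers frozen
            for p in mod.parameters(recurse=False):
                p.requires_grad = True
                if id(p) not in existing:
                    new_params.append(p)
                    existing.add(id(p))
    if new_params:
        # Fall back to the first group's lr scaled down by initial_denom_lr.
        group_lr = lr if lr is not None else optimizer.param_groups[0]["lr"] / initial_denom_lr
        optimizer.add_param_group({"params": new_params, "lr": group_lr})


# Usage: a frozen hypothetical backbone is unfrozen while only the head is optimized.
backbone = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
head = nn.Linear(8, 2)
for p in backbone.parameters():
    p.requires_grad = False
optimizer = torch.optim.SGD(head.parameters(), lr=0.1)

unfreeze_and_add_param_group(backbone, optimizer, train_bn=False)  # new group, lr = 0.1 / 10.0
unfreeze_and_add_param_group(backbone, optimizer, train_bn=False)  # nothing new, no empty group
```

Skipping the add_param_group call when nothing is new keeps repeated calls (for example, from an epoch hook) from creating empty param groups.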