BackboneFinetuning

class lightning.pytorch.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, rounding=12)[source]

Bases: BaseFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

When the backbone learning rate reaches the current model learning rate and should_align is set to True, the backbone learning rate stays aligned with the model learning rate for the rest of training.

Parameters:
  • unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.

  • lambda_func (Callable) – Scheduling function for increasing backbone learning rate.

  • backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.

  • backbone_initial_lr (Optional[float]) – Initial learning rate for the backbone. If not set, the initial value is current_learning_rate * backbone_initial_ratio_lr, so a ratio below 1 scales the backbone learning rate down.

  • should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.

  • initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.

  • train_bn (bool) – Whether to make Batch Normalization trainable.

  • verbose (bool) – Display the current learning rate for the model and the backbone.

  • rounding (int) – Precision (number of decimal places) used when displaying learning rates.
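
The interaction between unfreeze_backbone_at_epoch, lambda_func, backbone_initial_ratio_lr and should_align can be sketched in plain Python. This is only an illustration of the behaviour described above, not the library's implementation. The function name simulate_backbone_lr is hypothetical, and the sketch assumes that the model learning rate stays constant, that the initial backbone rate is the model rate multiplied by the ratio, and that the multiplier returned by lambda_func is applied once per epoch after unfreezing.

```python
from typing import Callable, List, Optional


def simulate_backbone_lr(
    model_lr: float,
    num_epochs: int,
    unfreeze_backbone_at_epoch: int = 10,
    lambda_func: Callable[[int], float] = lambda epoch: 2.0,
    backbone_initial_ratio_lr: float = 0.1,
    backbone_initial_lr: Optional[float] = None,
    should_align: bool = True,
) -> List[Optional[float]]:
    """Return the backbone learning rate per epoch (None while frozen).

    Hypothetical helper: it mirrors the documented parameters but is not
    part of Lightning.
    """
    schedule: List[Optional[float]] = []
    backbone_lr: Optional[float] = None
    for epoch in range(num_epochs):
        if epoch < unfreeze_backbone_at_epoch:
            # Backbone is still frozen, so it has no learning rate yet.
            schedule.append(None)
            continue
        if epoch == unfreeze_backbone_at_epoch:
            # Assumption: the ratio scales the model learning rate down.
            backbone_lr = (
                backbone_initial_lr
                if backbone_initial_lr is not None
                else model_lr * backbone_initial_ratio_lr
            )
        else:
            # Assumption: one multiplicative step per epoch.
            backbone_lr = lambda_func(epoch) * backbone_lr
            if should_align and backbone_lr > model_lr:
                # Never overshoot the model learning rate when aligning.
                backbone_lr = model_lr
        schedule.append(backbone_lr)
    return schedule


# Values chosen to be exact in floating point, not realistic learning rates.
schedule = simulate_backbone_lr(
    model_lr=1.0,
    num_epochs=7,
    unfreeze_backbone_at_epoch=2,
    backbone_initial_ratio_lr=0.125,
)
print(schedule)  # [None, None, 0.125, 0.25, 0.5, 1.0, 1.0]
```

With should_align=False the last value in this sketch would keep growing past the model learning rate instead of being capped.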

Example:

>>> from lightning.pytorch import Trainer
>>> from lightning.pytorch.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])

finetune_function(pl_module, epoch, optimizer)[source]

Called when the epoch begins.

Return type:

None

freeze_before_training(pl_module)[source]

Override to add your freeze logic.

Return type:

None

load_state_dict(state_dict)[source]

Called when loading a checkpoint. Implement to reload the callback state from the callback’s state_dict.

Parameters:

state_dict (dict[str, Any]) – the callback state returned by state_dict.

Return type:

None

on_fit_start(trainer, pl_module)[source]

Raises:

MisconfigurationException – If LightningModule has no nn.Module backbone attribute.

Return type:

None

state_dict()[source]

Called when saving a checkpoint. Implement to generate the callback’s state_dict.

Return type:

dict[str, Any]

Returns:

A dictionary containing callback state.
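
The state_dict / load_state_dict pair forms a save and restore round trip. The sketch below illustrates that contract with a hypothetical stand-in class in plain Python. The class name LrTrackingState and the key previous_backbone_lr are assumptions chosen for illustration, and the actual contents of this callback's state are not specified here.

```python
from typing import Any, Dict, Optional


class LrTrackingState:
    """Hypothetical stand-in for a callback that tracks a single value."""

    def __init__(self) -> None:
        self.previous_backbone_lr: Optional[float] = None

    def state_dict(self) -> Dict[str, Any]:
        # Called when saving: return everything needed to resume later.
        return {"previous_backbone_lr": self.previous_backbone_lr}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        # Called when loading: restore from what state_dict() returned.
        self.previous_backbone_lr = state_dict["previous_backbone_lr"]


# Simulate saving a checkpoint partway through training.
saved = LrTrackingState()
saved.previous_backbone_lr = 0.015
checkpointed = saved.state_dict()

# Simulate resuming: a fresh instance recovers the tracked value.
restored = LrTrackingState()
restored.load_state_dict(checkpointed)
```

After the round trip, restored holds the same tracked value as saved, which is what allows a schedule to continue from the point where training stopped.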