BackboneFinetuning
- class pytorch_lightning.callbacks.BackboneFinetuning(unfreeze_backbone_at_epoch=10, lambda_func=<function multiplicative>, backbone_initial_ratio_lr=0.1, backbone_initial_lr=None, should_align=True, initial_denom_lr=10.0, train_bn=True, verbose=False, round=12)[source]
Bases: pytorch_lightning.callbacks.finetuning.BaseFinetuning
Finetune a backbone model based on a user-defined learning rate schedule. When the backbone learning rate reaches the current model learning rate and should_align is set to True, it stays aligned with it for the rest of the training.
- Parameters
  - unfreeze_backbone_at_epoch (int) – Epoch at which the backbone will be unfrozen.
  - lambda_func (Callable) – Scheduling function for increasing the backbone learning rate.
  - backbone_initial_ratio_lr (float) – Used to scale down the backbone learning rate compared to the rest of the model.
  - backbone_initial_lr (Optional[float]) – Optional, initial learning rate for the backbone. By default, current_learning_rate * backbone_initial_ratio_lr is used.
  - should_align (bool) – Whether to align with the current learning rate when the backbone learning rate reaches it.
  - initial_denom_lr (float) – When unfreezing the backbone, the initial learning rate will be current_learning_rate / initial_denom_lr.
  - train_bn (bool) – Whether to make Batch Normalization trainable.
  - verbose (bool) – Display current learning rate for model and backbone.
  - round (int) – Precision for displaying the learning rate.
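As a rough illustration of the schedule these parameters describe (a sketch of the documented behaviour, not the library's internal code): the backbone learning rate starts below the model learning rate, is increased by lambda_func every epoch after unfreezing, and, with should_align=True, is capped at the model learning rate once it reaches it.

    # Illustrative sketch only; assumes a constant model learning rate.
    model_lr = 1e-3
    backbone_initial_ratio_lr = 0.1
    unfreeze_backbone_at_epoch = 10
    should_align = True
    lambda_func = lambda epoch: 1.5  # multiplicative increase per epoch

    backbone_lr = model_lr * backbone_initial_ratio_lr  # backbone starts at a fraction of the model lr
    for epoch in range(unfreeze_backbone_at_epoch, unfreeze_backbone_at_epoch + 8):
        backbone_lr *= lambda_func(epoch)
        if should_align and backbone_lr >= model_lr:
            backbone_lr = model_lr  # stay aligned with the model lr from here on
        print(f"epoch {epoch}: backbone_lr={backbone_lr:.6f}")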
Example:
>>> from pytorch_lightning import Trainer
>>> from pytorch_lightning.callbacks import BackboneFinetuning
>>> multiplicative = lambda epoch: 1.5
>>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
>>> trainer = Trainer(callbacks=[backbone_finetuning])
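The callback expects the LightningModule to store its pretrained feature extractor as an nn.Module attribute named backbone (see on_fit_start below). A minimal sketch of such a module; the head, the feature size and the optimizer settings are illustrative assumptions:

    import torch
    from torch import nn
    from pytorch_lightning import LightningModule

    class FinetuneModel(LightningModule):
        """Sketch: the pretrained feature extractor must be exposed as ``self.backbone``."""

        def __init__(self, pretrained_backbone: nn.Module):
            super().__init__()
            self.backbone = pretrained_backbone  # frozen by BackboneFinetuning before training starts
            self.head = nn.Linear(512, 10)       # hypothetical head: 512 backbone features, 10 classes

        def forward(self, x):
            return self.head(self.backbone(x))

        def training_step(self, batch, batch_idx):
            x, y = batch
            return nn.functional.cross_entropy(self(x), y)

        def configure_optimizers(self):
            return torch.optim.Adam(self.parameters(), lr=1e-3)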
- finetune_function(pl_module, epoch, optimizer, opt_idx)[source]
Called when the epoch begins.
- freeze_before_training(pl_module)[source]
Override to add your freeze logic.
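freeze_before_training and finetune_function are the two hooks BaseFinetuning exposes for custom finetuning strategies. A minimal sketch of a custom callback overriding them; the feature_extractor attribute name and the unfreeze epoch are illustrative assumptions:

    from pytorch_lightning.callbacks.finetuning import BaseFinetuning

    class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        """Sketch: freeze a submodule before training, unfreeze it at a chosen epoch."""

        def __init__(self, unfreeze_at_epoch: int = 10):
            super().__init__()
            self._unfreeze_at_epoch = unfreeze_at_epoch

        def freeze_before_training(self, pl_module):
            # Assumes the LightningModule exposes a ``feature_extractor`` submodule (hypothetical name).
            self.freeze(pl_module.feature_extractor)

        def finetune_function(self, pl_module, epoch, optimizer, opt_idx):
            # Called when the epoch begins; put the frozen parameters back into the optimizer once reached.
            if epoch == self._unfreeze_at_epoch:
                self.unfreeze_and_add_param_group(
                    modules=pl_module.feature_extractor,
                    optimizer=optimizer,
                    train_bn=True,
                )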
- on_fit_start(trainer, pl_module)[source]
- Raises
MisconfigurationException – If LightningModule has no nn.Module backbone attribute.
- on_load_checkpoint(trainer, pl_module, callback_state)[source]
Called when loading a model checkpoint, use to reload state.
- Parameters
  - trainer (Trainer) – the current Trainer instance.
  - pl_module (LightningModule) – the current LightningModule instance.
  - callback_state (Dict[int, List[Dict[str, Any]]]) – the callback state returned by on_save_checkpoint.
- Return type
  None
Note
The on_load_checkpoint won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
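A minimal sketch of what this note implies for a custom callback whose on_load_checkpoint logic does not depend on any saved state (the callback itself is hypothetical):

    from pytorch_lightning.callbacks import Callback

    class StatelessCallback(Callback):
        """Sketch: return a dummy state so that on_load_checkpoint is still invoked on resume."""

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            # Without a saved state, on_load_checkpoint would not be called on resume.
            return {"dummy_state": True}

        def on_load_checkpoint(self, trainer, pl_module, callback_state):
            # State-independent behaviour, e.g. re-attaching loggers or recomputing derived attributes.
            print("checkpoint loaded, callback_state:", callback_state)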
- on_save_checkpoint(trainer, pl_module, checkpoint)[source]
Called when saving a model checkpoint, use to persist state.
- Parameters
  - trainer (Trainer) – the current Trainer instance.
  - pl_module (LightningModule) – the current LightningModule instance.
  - checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
- Return type
  Dict[int, List[Dict[str, Any]]]
- Returns
The callback state.