Source code for pytorch_lightning.callbacks.finetuning
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Finetuning Callback
^^^^^^^^^^^^^^^^^^^^

Freeze and unfreeze models for finetuning purposes
"""
import logging
from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Union

import torch
from torch.nn import Module, ModuleDict
from torch.nn.modules.batchnorm import _BatchNorm
from torch.optim.optimizer import Optimizer

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_warn

log = logging.getLogger(__name__)


def multiplicative(epoch):
    return 2

class BaseFinetuning(Callback):
    r"""
    This class implements the base logic for writing your own Finetuning Callback.

    Override ``freeze_before_training`` and ``finetune_function`` methods with your own logic.

    ``freeze_before_training``: This method is called before ``configure_optimizers``
        and should be used to freeze any module's parameters.

    ``finetune_function``: This method is called on every train epoch start and should be used to
        ``unfreeze`` any parameters. Those parameters need to be added in a new ``param_group``
        within the optimizer.

    .. note:: Make sure to filter the parameters based on ``requires_grad``.

    Example::

        >>> from torch.optim import Adam
        >>> class MyModel(pl.LightningModule):
        ...     def configure_optimizer(self):
        ...         # Make sure to filter the parameters based on `requires_grad`
        ...         return Adam(filter(lambda p: p.requires_grad, self.parameters()))
        ...
        >>> class FeatureExtractorFreezeUnfreeze(BaseFinetuning):
        ...     def __init__(self, unfreeze_at_epoch=10):
        ...         super().__init__()
        ...         self._unfreeze_at_epoch = unfreeze_at_epoch
        ...
        ...     def freeze_before_training(self, pl_module):
        ...         # freeze any module you want
        ...         # Here, we are freezing `feature_extractor`
        ...         self.freeze(pl_module.feature_extractor)
        ...
        ...     def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        ...         # When `current_epoch` is 10, feature_extractor will start training.
        ...         if current_epoch == self._unfreeze_at_epoch:
        ...             self.unfreeze_and_add_param_group(
        ...                 modules=pl_module.feature_extractor,
        ...                 optimizer=optimizer,
        ...                 train_bn=True,
        ...             )
    """

    def __init__(self):
        self._internal_optimizer_metadata: Dict[int, List[Dict[str, Any]]] = {}
        self._restarting = False

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self._restarting = True
        if "internal_optimizer_metadata" in state_dict:
            self._internal_optimizer_metadata = state_dict["internal_optimizer_metadata"]
        else:
            # compatibility to load from old checkpoints before PR #11887
            self._internal_optimizer_metadata = state_dict

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        # restore the param_groups created during the previous training.
        if self._restarting:
            named_parameters = dict(pl_module.named_parameters())
            for opt_idx, optimizer in enumerate(trainer.optimizers):
                param_groups = self._apply_mapping_to_param_groups(
                    self._internal_optimizer_metadata[opt_idx], named_parameters
                )
                optimizer.param_groups = param_groups
            self._restarting = False

    @staticmethod
    def flatten_modules(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> List[Module]:
        """This function is used to flatten a module or an iterable of modules into a list of its leaf modules
        (modules with no children) and parent modules that have parameters directly themselves.

        Args:
            modules: A given module or an iterable of modules

        Returns:
            List of modules
        """
        if isinstance(modules, ModuleDict):
            modules = modules.values()

        if isinstance(modules, Iterable):
            _modules = []
            for m in modules:
                _modules.extend(BaseFinetuning.flatten_modules(m))
        else:
            _modules = modules.modules()

        # Capture all leaf modules as well as parent modules that have parameters directly themselves
        return [m for m in _modules if not list(m.children()) or m._parameters]

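    # A minimal usage sketch of ``flatten_modules``, assuming a small ``nn.Sequential``
    # backbone: the nested container is flattened down to its leaf modules.
    #
    #     >>> from torch import nn
    #     >>> backbone = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    #     >>> [type(m).__name__ for m in BaseFinetuning.flatten_modules(backbone)]
    #     ['Linear', 'ReLU', 'Linear']
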
    @staticmethod
    def filter_params(
        modules: Union[Module, Iterable[Union[Module, Iterable]]],
        train_bn: bool = True,
        requires_grad: bool = True,
    ) -> Generator:
        """Yields the `requires_grad` parameters of a given module or list of modules.

        Args:
            modules: A given module or an iterable of modules
            train_bn: Whether to train BatchNorm modules
            requires_grad: Whether to create a generator for trainable or non-trainable parameters.

        Returns:
            Generator
        """
        modules = BaseFinetuning.flatten_modules(modules)
        for mod in modules:
            if isinstance(mod, _BatchNorm) and not train_bn:
                continue
            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
            for param in mod.parameters(recurse=False):
                if param.requires_grad == requires_grad:
                    yield param

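    # A minimal sketch of ``filter_params``, assuming a hypothetical two-layer net with
    # the first layer already frozen: only parameters whose ``requires_grad`` matches
    # the requested flag are yielded.
    #
    #     >>> from torch import nn
    #     >>> net = nn.Sequential(nn.Linear(4, 8), nn.Linear(8, 2))
    #     >>> _ = net[0].requires_grad_(False)
    #     >>> sum(1 for _ in BaseFinetuning.filter_params(net, requires_grad=True))
    #     2
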
    @staticmethod
    def make_trainable(modules: Union[Module, Iterable[Union[Module, Iterable]]]) -> None:
        """Unfreezes the parameters of the provided modules.

        Args:
            modules: A given module or an iterable of modules
        """
        modules = BaseFinetuning.flatten_modules(modules)
        for module in modules:
            # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
            for param in module.parameters(recurse=False):
                param.requires_grad = True

    @staticmethod
    def freeze(modules: Union[Module, Iterable[Union[Module, Iterable]]], train_bn: bool = True) -> None:
        """Freezes the parameters of the provided modules.

        Args:
            modules: A given module or an iterable of modules
            train_bn: If True, leave the BatchNorm layers in training mode

        Returns:
            None
        """
        modules = BaseFinetuning.flatten_modules(modules)
        for mod in modules:
            if isinstance(mod, _BatchNorm) and train_bn:
                BaseFinetuning.make_trainable(mod)
            else:
                # recursion could yield duplicate parameters for parent modules w/ parameters so disabling it
                for param in mod.parameters(recurse=False):
                    param.requires_grad = False

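    # A minimal sketch of ``freeze`` / ``make_trainable``, assuming a hypothetical
    # two-layer ``encoder``: with ``train_bn=True`` the BatchNorm parameters stay
    # trainable while everything else is frozen.
    #
    #     >>> from torch import nn
    #     >>> encoder = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8))
    #     >>> BaseFinetuning.freeze(encoder, train_bn=True)
    #     >>> encoder[0].weight.requires_grad, encoder[1].weight.requires_grad
    #     (False, True)
    #     >>> BaseFinetuning.make_trainable(encoder)  # everything trainable again
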
    @staticmethod
    def filter_on_optimizer(optimizer: Optimizer, params: Iterable) -> List:
        """This function is used to exclude any parameter which already exists in this optimizer.

        Args:
            optimizer: Optimizer used for parameter exclusion
            params: Iterable of parameters used to check against the provided optimizer

        Returns:
            List of parameters not contained in this optimizer param groups
        """
        out_params = []
        removed_params = []
        for param in params:
            if not any(torch.equal(p, param) for group in optimizer.param_groups for p in group["params"]):
                out_params.append(param)
            else:
                removed_params.append(param)

        if removed_params:
            rank_zero_warn(
                "The provided params to be frozen already exist within another group of this optimizer."
                " Those parameters will be skipped.\n"
                "HINT: Did you init your optimizer in `configure_optimizer` as such:\n"
                f" {type(optimizer)}(filter(lambda p: p.requires_grad, self.parameters()), ...) ",
            )
        return out_params

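    # A minimal sketch of ``filter_on_optimizer``, assuming a hypothetical ``head`` /
    # ``backbone`` split where the optimizer was built from the head only: the backbone
    # parameters pass the filter, while parameters already in the optimizer would be
    # skipped with a warning.
    #
    #     >>> from torch import nn
    #     >>> head, backbone = nn.Linear(8, 2), nn.Linear(4, 8)
    #     >>> optimizer = torch.optim.SGD(head.parameters(), lr=0.1)
    #     >>> len(BaseFinetuning.filter_on_optimizer(optimizer, backbone.parameters()))
    #     2
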
    @staticmethod
    def unfreeze_and_add_param_group(
        modules: Union[Module, Iterable[Union[Module, Iterable]]],
        optimizer: Optimizer,
        lr: Optional[float] = None,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True,
    ) -> None:
        """Unfreezes a module and adds its parameters to an optimizer.

        Args:
            modules: A module or iterable of modules to unfreeze.
                Their parameters will be added to an optimizer as a new param group.
            optimizer: The provided optimizer will receive the new parameters via `add_param_group`.
            lr: Learning rate for the new param group.
            initial_denom_lr: If no lr is provided, the learning rate from the first param group
                will be used and divided by `initial_denom_lr`.
            train_bn: Whether to train the BatchNormalization layers.
        """
        BaseFinetuning.make_trainable(modules)
        params_lr = optimizer.param_groups[0]["lr"] if lr is None else float(lr)
        denom_lr = initial_denom_lr if lr is None else 1.0
        params = BaseFinetuning.filter_params(modules, train_bn=train_bn, requires_grad=True)
        params = BaseFinetuning.filter_on_optimizer(optimizer, params)
        if params:
            optimizer.add_param_group({"params": params, "lr": params_lr / denom_lr})

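    # A minimal sketch of ``unfreeze_and_add_param_group``, assuming a hypothetical
    # frozen ``backbone`` and an optimizer built from the ``head`` with lr=1.0: when no
    # explicit ``lr`` is given, the new group uses the first group's lr divided by
    # ``initial_denom_lr``.
    #
    #     >>> from torch import nn
    #     >>> backbone, head = nn.Linear(4, 8), nn.Linear(8, 2)
    #     >>> BaseFinetuning.freeze(backbone)
    #     >>> optimizer = torch.optim.SGD(head.parameters(), lr=1.0)
    #     >>> BaseFinetuning.unfreeze_and_add_param_group(backbone, optimizer)
    #     >>> optimizer.param_groups[-1]["lr"]  # 1.0 / initial_denom_lr
    #     0.1
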
    @staticmethod
    def _apply_mapping_to_param_groups(param_groups: List[Dict[str, Any]], mapping: dict) -> List[Dict[str, Any]]:
        output = []
        for g in param_groups:
            # skip params to save memory
            group_state = {k: v for k, v in g.items() if k != "params"}
            group_state["params"] = [mapping[p] for p in g["params"]]
            output.append(group_state)
        return output

    def _store(
        self,
        pl_module: "pl.LightningModule",
        opt_idx: int,
        num_param_groups: int,
        current_param_groups: List[Dict[str, Any]],
    ) -> None:
        mapping = {p: n for n, p in pl_module.named_parameters()}
        if opt_idx not in self._internal_optimizer_metadata:
            self._internal_optimizer_metadata[opt_idx] = self._apply_mapping_to_param_groups(
                current_param_groups, mapping
            )
        elif num_param_groups != len(current_param_groups):
            # save new param_groups possibly created by the users.
            self._internal_optimizer_metadata[opt_idx].extend(
                self._apply_mapping_to_param_groups(current_param_groups[num_param_groups:], mapping)
            )

    def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """Called when the epoch begins."""
        # import is here to avoid circular imports
        from pytorch_lightning.loops.utilities import _get_active_optimizers

        for opt_idx, optimizer in _get_active_optimizers(trainer.optimizers, trainer.optimizer_frequencies):
            num_param_groups = len(optimizer.param_groups)
            self.finetune_function(pl_module, trainer.current_epoch, optimizer, opt_idx)
            current_param_groups = optimizer.param_groups
            self._store(pl_module, opt_idx, num_param_groups, current_param_groups)

    def finetune_function(
        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
    ) -> None:
        """Override to add your unfreeze logic."""
        raise NotImplementedError

    def freeze_before_training(self, pl_module: "pl.LightningModule") -> None:
        """Override to add your freeze logic."""
        raise NotImplementedError


class BackboneFinetuning(BaseFinetuning):
    r"""Finetune a backbone model based on a user-defined learning rate schedule.

    When the backbone learning rate reaches the current model learning rate and ``should_align`` is set to True,
    it will align with it for the rest of the training.

    Args:
        unfreeze_backbone_at_epoch: Epoch at which the backbone will be unfrozen.
        lambda_func: Scheduling function for increasing the backbone learning rate.
        backbone_initial_ratio_lr: Used to scale down the backbone learning rate compared to the rest of the model.
        backbone_initial_lr: Optional, initial learning rate for the backbone.
            By default, we will use ``current_learning_rate * backbone_initial_ratio_lr``.
        should_align: Whether to align with the current learning rate when the backbone learning rate reaches it.
        initial_denom_lr: When unfreezing the backbone, the initial learning rate will be
            ``current_learning_rate / initial_denom_lr``.
        train_bn: Whether to make Batch Normalization trainable.
        verbose: Display the current learning rate for model and backbone.
        rounding: Precision for displaying the learning rate.

    Example::

        >>> from pytorch_lightning import Trainer
        >>> from pytorch_lightning.callbacks import BackboneFinetuning
        >>> multiplicative = lambda epoch: 1.5
        >>> backbone_finetuning = BackboneFinetuning(200, multiplicative)
        >>> trainer = Trainer(callbacks=[backbone_finetuning])
    """

    def __init__(
        self,
        unfreeze_backbone_at_epoch: int = 10,
        lambda_func: Callable = multiplicative,
        backbone_initial_ratio_lr: float = 10e-2,
        backbone_initial_lr: Optional[float] = None,
        should_align: bool = True,
        initial_denom_lr: float = 10.0,
        train_bn: bool = True,
        verbose: bool = False,
        rounding: int = 12,
    ) -> None:
        super().__init__()
        self.unfreeze_backbone_at_epoch: int = unfreeze_backbone_at_epoch
        self.lambda_func: Callable = lambda_func
        self.backbone_initial_ratio_lr: float = backbone_initial_ratio_lr
        self.backbone_initial_lr: Optional[float] = backbone_initial_lr
        self.should_align: bool = should_align
        self.initial_denom_lr: float = initial_denom_lr
        self.train_bn: bool = train_bn
        self.verbose: bool = verbose
        self.rounding: int = rounding
        self.previous_backbone_lr: Optional[float] = None

    def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        """
        Raises:
            MisconfigurationException:
                If LightningModule has no nn.Module `backbone` attribute.
        """
        if hasattr(pl_module, "backbone") and isinstance(pl_module.backbone, Module):
            return super().on_fit_start(trainer, pl_module)
        raise MisconfigurationException("The LightningModule should have a nn.Module `backbone` attribute")

    def finetune_function(
        self, pl_module: "pl.LightningModule", epoch: int, optimizer: Optimizer, opt_idx: int
    ) -> None:
        """Called when the epoch begins."""
        if epoch == self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]["lr"]
            initial_backbone_lr = (
                self.backbone_initial_lr
                if self.backbone_initial_lr is not None
                else current_lr * self.backbone_initial_ratio_lr
            )
            self.previous_backbone_lr = initial_backbone_lr
            self.unfreeze_and_add_param_group(
                pl_module.backbone,
                optimizer,
                initial_backbone_lr,
                train_bn=self.train_bn,
                initial_denom_lr=self.initial_denom_lr,
            )
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.rounding)}, "
                    f"Backbone lr: {round(initial_backbone_lr, self.rounding)}"
                )

        elif epoch > self.unfreeze_backbone_at_epoch:
            current_lr = optimizer.param_groups[0]["lr"]
            next_current_backbone_lr = self.lambda_func(epoch + 1) * self.previous_backbone_lr
            next_current_backbone_lr = (
                current_lr
                if (self.should_align and next_current_backbone_lr > current_lr)
                else next_current_backbone_lr
            )
            optimizer.param_groups[-1]["lr"] = next_current_backbone_lr
            self.previous_backbone_lr = next_current_backbone_lr
            if self.verbose:
                log.info(
                    f"Current lr: {round(current_lr, self.rounding)}, "
                    f"Backbone lr: {round(next_current_backbone_lr, self.rounding)}"
                )

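# A rough sketch of the backbone learning rate schedule produced by the defaults above
# (``unfreeze_backbone_at_epoch=10``, ``backbone_initial_ratio_lr=0.1``, ``should_align=True``,
# ``multiplicative`` doubling per epoch), assuming the model's param group stays at lr=0.1:
#
#     epoch 10: backbone lr = 0.1 * 0.1 = 0.01   (backbone unfrozen, added as a new param group)
#     epoch 11: backbone lr = 2 * 0.01 = 0.02
#     epoch 12: backbone lr = 2 * 0.02 = 0.04
#     epoch 13: backbone lr = 2 * 0.04 = 0.08
#     epoch 14: 2 * 0.08 = 0.16 > 0.1, so the backbone lr is aligned to the current lr (0.1)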