Source code for pytorch_lightning.utilities.parsing
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Utilities used for parameter parsing."""importcopyimportinspectimportpickleimporttypesfromargparseimportNamespacefromdataclassesimportfields,is_dataclassfromtypingimportAny,Dict,List,Optional,Sequence,Tuple,Type,Unionfromtorchimportnnfromtyping_extensionsimportLiteralimportpytorch_lightningasplfrompytorch_lightning.utilities.rank_zeroimportrank_zero_warn
def str_to_bool_or_str(val: str) -> Union[str, bool]:
    """Translate a truth-like string into a bool, passing anything else through untouched.

    The comparison is case-insensitive and modelled on ``distutils.utils.strtobool``.

    Args:
        val: the string to interpret.

    Returns:
        ``True`` for 'y', 'yes', 't', 'true', 'on' and '1'; ``False`` for 'n', 'no', 'f', 'false',
        'off' and '0'; otherwise ``val`` itself, with its original casing.
    """
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")

    normalized = val.lower()
    if normalized in truthy:
        return True
    # anything that is not a recognised false value is handed back unchanged
    return False if normalized in falsy else val
def str_to_bool(val: str) -> bool:
    """Strictly convert a truth-like string into a bool.

    Accepted true values are 'y', 'yes', 't', 'true', 'on', and '1'; accepted false values are
    'n', 'no', 'f', 'false', 'off', and '0'.

    Args:
        val: the string to interpret.

    Returns:
        The boolean that ``val`` stands for.

    Raises:
        ValueError:
            If ``val`` isn't in one of the aforementioned true or false values.

    >>> str_to_bool('YES')
    True
    >>> str_to_bool('FALSE')
    False
    """
    outcome = str_to_bool_or_str(val)
    if not isinstance(outcome, bool):
        # the lenient helper gave the string back, so it is not a recognised truth value
        raise ValueError(f"invalid truth value {outcome}")
    return outcome
def str_to_bool_or_int(val: str) -> Union[bool, int, str]:
    """Interpret a string as a bool when possible, then as an int, else return it unchanged.

    Args:
        val: the string to interpret.

    Returns:
        A bool if ``val`` is a recognised truth value, an int if ``int()`` accepts it,
        otherwise the original string.

    >>> str_to_bool_or_int("FALSE")
    False
    >>> str_to_bool_or_int("1")
    True
    >>> str_to_bool_or_int("2")
    2
    >>> str_to_bool_or_int("abc")
    'abc'
    """
    outcome = str_to_bool_or_str(val)
    if not isinstance(outcome, bool):
        try:
            outcome = int(outcome)
        except ValueError:
            # not an integer literal either: keep the string as it is
            pass
    return outcome
def is_picklable(obj: object) -> bool:
    """Tests if an object can be pickled.

    Args:
        obj: the object to probe.

    Returns:
        ``True`` if ``pickle.dumps(obj)`` succeeds, ``False`` if it fails with one of the
        errors pickling is known to raise for unsupported objects.
    """
    try:
        pickle.dumps(obj)
    except (pickle.PickleError, AttributeError, RuntimeError, TypeError):
        # ``TypeError`` is what ``pickle`` raises for objects without pickling support at all
        # (e.g. locks or generators: "cannot pickle '...' object"); without it this predicate
        # would raise instead of answering ``False``.
        return False
    return True
def clean_namespace(hparams: Union[Dict[str, Any], Namespace]) -> None:
    """Strip every entry that cannot be pickled out of ``hparams``, in place.

    A warning is emitted through ``rank_zero_warn`` for each removed entry.

    Args:
        hparams: a dict, or an ``argparse.Namespace`` whose attribute dict is edited directly.
    """
    # ``vars`` exposes the namespace's own ``__dict__``, so deletions below affect the namespace
    container = vars(hparams) if isinstance(hparams, Namespace) else hparams

    # collect the offending keys first: the mapping must not change size while being iterated
    unpicklable = [name for name, value in container.items() if not is_picklable(value)]
    for name in unpicklable:
        rank_zero_warn(f"attribute '{name}' removed from hparams because it cannot be pickled")
        del container[name]
def parse_class_init_keys(cls: Type["pl.LightningModule"]) -> Tuple[str, Optional[str], Optional[str]]:
    """Report the names ``cls.__init__`` uses for ``self``, ``*args`` and ``**kwargs``.

    Args:
        cls: the class whose ``__init__`` signature is inspected.

    Returns:
        A tuple ``(self_name, var_positional_name, var_keyword_name)``. The last two are
        ``None`` when the signature has no ``*args`` / ``**kwargs`` parameter.

    Examples:

        >>> class Model():
        ...     def __init__(self, hparams, *my_args, anykw=42, **my_kwargs):
        ...         pass

        >>> parse_class_init_keys(Model)
        ('self', 'my_args', 'my_kwargs')
    """
    # ``Signature.parameters`` is documented as ordered, so the first entry is the instance parameter
    # https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters
    signature_params = list(inspect.signature(cls.__init__).parameters.values())

    def _first_name_of_kind(kind: inspect._ParameterKind) -> Optional[str]:
        # name of the first parameter of the requested kind, or None if there is none
        return next((param.name for param in signature_params if param.kind == kind), None)

    return (
        signature_params[0].name,
        _first_name_of_kind(inspect.Parameter.VAR_POSITIONAL),
        _first_name_of_kind(inspect.Parameter.VAR_KEYWORD),
    )
def get_init_args(frame: types.FrameType) -> Dict[str, Any]:
    """Extract the constructor arguments visible in a stack frame.

    Args:
        frame: a stack frame, expected to belong to an ``__init__`` call.

    Returns:
        A dict of the frame's local values for the names in the ``__init__`` signature of the
        frame's ``__class__``, with the contents of ``**kwargs`` merged in and the ``self``,
        ``*args`` and ``**kwargs`` names themselves dropped. Empty if the frame has no
        ``__class__`` local.
    """
    frame_locals = inspect.getargvalues(frame).locals
    if "__class__" not in frame_locals:
        return {}

    owner = frame_locals["__class__"]
    signature_names = inspect.signature(owner.__init__).parameters
    self_name, var_args_name, var_kwargs_name = parse_class_init_keys(owner)

    # names that must never be reported as hyperparameters
    special_names = tuple(name for name in (self_name, var_args_name, var_kwargs_name) if name)
    excluded = (*special_names, "__class__", "frame", "frame_args")

    # only collect variables that appear in the signature
    collected = {name: frame_locals[name] for name in signature_names}
    # var_kwargs_name is None when the signature has no ``**kwargs``
    if var_kwargs_name:
        # flatten the content of ``**kwargs`` into the top level
        collected.update(collected.get(var_kwargs_name, {}))
    return {name: value for name, value in collected.items() if name not in excluded}
def collect_init_args(
    frame: types.FrameType, path_args: List[Dict[str, Any]], inside: bool = False
) -> List[Dict[str, Any]]:
    """Walk up the call stack and gather the constructor arguments of every class in the inheritance chain.

    Args:
        frame: the current stack frame
        path_args: a list of dictionaries containing the constructor args in all parent classes;
            it is appended to in place and returned
        inside: track if we are inside inheritance path, avoid terminating too soon

    Return:
        A list of dictionaries where each dictionary contains the arguments passed to the
        constructor at that level. The last entry corresponds to the constructor call of the
        most specific class in the hierarchy.
    """
    # a frame without a real parent frame ends the walk without being inspected
    while isinstance(frame.f_back, types.FrameType):
        frame_locals = inspect.getargvalues(frame).locals
        if "__class__" in frame_locals:
            # this frame has a ``__class__`` local: record its arguments and keep climbing
            path_args.append(get_init_args(frame))
            inside = True
        elif inside:
            # first frame past the chain of ``__class__`` frames: stop here
            break
        frame = frame.f_back
    return path_args
def save_hyperparameters(
    obj: Any, *args: Any, ignore: Optional[Union[Sequence[str], str]] = None, frame: Optional[types.FrameType] = None
) -> None:
    """See :meth:`~pytorch_lightning.LightningModule.save_hyperparameters`

    Args:
        obj: the object receiving the hyperparameters. It must provide ``_set_hparams`` and,
            after that call, an ``_hparams`` mapping.
        *args: nothing (save every collected init argument), the names of the init arguments
            to save, or a single non-string object to store as the hyperparameters as-is.
        ignore: one argument name, or a sequence of names, to leave out of the collected
            init arguments.
        frame: the frame to collect init arguments from; defaults to the caller's frame.

    Raises:
        AttributeError: If no frame is given and none can be obtained from the interpreter.
    """
    # a lone falsy positional argument (e.g. an empty container) leaves nothing to save
    if len(args) == 1 and not args[0]:
        return

    if not frame:
        # ``inspect.currentframe()`` may return None, in which case there is no caller frame to use
        here = inspect.currentframe()
        if here:
            frame = here.f_back
    if not isinstance(frame, types.FrameType):
        raise AttributeError("There is no `frame` available while being required.")

    if is_dataclass(obj):
        # dataclasses carry their init arguments as fields
        init_args = {spec.name: getattr(obj, spec.name) for spec in fields(obj)}
    else:
        init_args = {}
        for level_args in collect_init_args(frame, []):
            init_args.update(level_args)

    # normalise ``ignore`` into a collection of unique names
    if ignore is None:
        skipped = set()
    elif isinstance(ignore, str):
        skipped = {ignore}
    elif isinstance(ignore, (list, tuple)):
        skipped = {name for name in ignore if isinstance(name, str)}
    else:
        skipped = set(ignore)
    init_args = {name: value for name, value in init_args.items() if name not in skipped}

    hparams_name: Optional[str] = "kwargs"
    if args:
        # take only listed arguments in `save_hparams`
        non_str_positions = [pos for pos, arg in enumerate(args) if not isinstance(arg, str)]
        if len(non_str_positions) == 1:
            # exactly one non-string argument: it is the hyperparameters object itself
            hp = args[non_str_positions[0]]
            matching = [name for name, value in init_args.items() if value == hp]
            hparams_name = matching[0] if matching else None
        else:
            hp = {arg: init_args[arg] for arg in args if isinstance(arg, str)}
    else:
        # take all arguments
        hp = init_args
        hparams_name = "kwargs" if hp else None
    obj._hparams_name = hparams_name

    # `hparams` are expected here
    obj._set_hparams(hp)
    # make deep copy so there is not other runtime changes reflected
    obj._hparams_initial = copy.deepcopy(obj._hparams)

    for name, value in obj._hparams.items():
        if isinstance(value, nn.Module):
            rank_zero_warn(
                f"Attribute {name!r} is an instance of `nn.Module` and is already saved during checkpointing."
                f" It is recommended to ignore them using `self.save_hyperparameters(ignore=[{name!r}])`."
            )
def _lightning_get_all_attr_holders(model: "pl.LightningModule", attribute: str) -> List[Any]:
    """Special attribute finding for Lightning.

    Collects every object or dict that holds ``attribute``, looking in this order at the model
    namespace, the old hparams namespace/dict, and the datamodule attached to the model's trainer.

    Args:
        model: the module to search.
        attribute: the attribute name to look for.

    Returns:
        The holders in lookup order; empty if nothing holds ``attribute``.
    """
    trainer = getattr(model, "trainer", None)

    # Check if attribute in model
    found: List[Any] = [model] if hasattr(model, attribute) else []

    # Check if attribute in model.hparams, either namespace or dict
    if hasattr(model, "hparams") and attribute in model.hparams:
        found.append(model.hparams)

    # Check if the attribute in datamodule (datamodule gets registered in Trainer)
    datamodule = None if trainer is None else trainer.datamodule
    if datamodule is not None and hasattr(datamodule, attribute):
        found.append(datamodule)

    return found


def _lightning_get_first_attr_holder(model: "pl.LightningModule", attribute: str) -> Optional[Any]:
    """Special attribute finding for Lightning.

    Picks the holder of ``attribute`` from the model namespace, the old hparams namespace/dict,
    and the datamodule. Despite the function name, the *last* holder in that lookup order wins.

    Args:
        model: the module to search.
        attribute: the attribute name to look for.

    Returns:
        The selected holder, or ``None`` if nothing holds ``attribute``.
    """
    found = _lightning_get_all_attr_holders(model, attribute)
    # using the last holder to preserve backwards compatibility
    return found[-1] if found else None
def lightning_hasattr(model: "pl.LightningModule", attribute: str) -> bool:
    """Special hasattr for Lightning.

    Args:
        model: the module to search.
        attribute: the attribute name to look for.

    Returns:
        ``True`` if ``attribute`` is found in the model namespace, the old hparams
        namespace/dict, or the datamodule; ``False`` otherwise.
    """
    holder = _lightning_get_first_attr_holder(model, attribute)
    return holder is not None
def lightning_getattr(model: "pl.LightningModule", attribute: str) -> Optional[Any]:
    """Special getattr for Lightning.

    Looks ``attribute`` up in the model namespace, the old hparams namespace/dict, and the
    datamodule, and reads it from the holder selected by ``_lightning_get_first_attr_holder``.

    Args:
        model: the module to search.
        attribute: the attribute name to read.

    Returns:
        The value stored under ``attribute`` on the selected holder.

    Raises:
        AttributeError:
            If ``model`` doesn't have ``attribute`` in any of
            model namespace, the hparams namespace/dict, and the datamodule.
    """
    holder = _lightning_get_first_attr_holder(model, attribute)
    if holder is not None:
        # dict holders are read by key, every other holder by attribute access
        return holder[attribute] if isinstance(holder, dict) else getattr(holder, attribute)

    raise AttributeError(
        f"{attribute} is neither stored in the model namespace"
        " nor the `hparams` namespace/dict, nor the datamodule."
    )
def lightning_setattr(model: "pl.LightningModule", attribute: str, value: Any) -> None:
    """Special setattr for Lightning.

    Writes ``value`` to *every* place that currently holds ``attribute``: the model namespace,
    the old hparams namespace/dict, and the datamodule, whichever of them have it.

    Args:
        model: the module to search.
        attribute: the attribute name to overwrite.
        value: the new value.

    Raises:
        AttributeError:
            If ``model`` doesn't have ``attribute`` in any of
            model namespace, the hparams namespace/dict, and the datamodule.
    """
    targets = _lightning_get_all_attr_holders(model, attribute)
    if not targets:
        raise AttributeError(
            f"{attribute} is neither stored in the model namespace"
            " nor the `hparams` namespace/dict, nor the datamodule."
        )

    for target in targets:
        # dict holders are written by key, every other holder by attribute assignment
        if not isinstance(target, dict):
            setattr(target, attribute, value)
        else:
            target[attribute] = value
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.