Source code for pytorch_lightning.utilities.argparse
# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Utilities for Argument Parsing within Lightning Components."""importinspectimportosfromabcimportABCfromargparseimport_ArgumentGroup,ArgumentParser,Namespacefromastimportliteral_evalfromcontextlibimportsuppressfromfunctoolsimportwrapsfromtypingimportAny,Callable,cast,Dict,List,Tuple,Type,TypeVar,Unionimportpytorch_lightningasplfrompytorch_lightning.utilities.parsingimportstr_to_bool,str_to_bool_or_int,str_to_bool_or_str_T=TypeVar("_T",bound=Callable[...,Any])
def from_argparse_args(
    cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any
) -> ParseArgparserDataType:
    """Instantiate ``cls`` from parsed command-line arguments.

    Only entries of ``args`` whose names match a parameter of ``cls.__init__`` are forwarded to the
    constructor; any other entry is ignored.

    Args:
        cls: Lightning class to instantiate.
        args: The parser or namespace to take arguments from. An ``ArgumentParser`` is first turned into
            a namespace by calling ``cls.parse_argparser``.
        **kwargs: Additional keyword arguments. They are applied last, so they override values coming
            from ``args``, and they are passed to ``cls`` unfiltered.

    Returns:
        The result of calling ``cls`` with the collected keyword arguments.

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> parser = ArgumentParser(add_help=False)
        >>> parser = Trainer.add_argparse_args(parser)
        >>> parser.add_argument('--my_custom_arg', default='something')  # doctest: +SKIP
        >>> args = Trainer.parse_argparser(parser.parse_args(""))
        >>> trainer = Trainer.from_argparse_args(args, logger=False)
    """
    namespace = cls.parse_argparser(args) if isinstance(args, ArgumentParser) else args
    available = vars(namespace)

    # Forward only what the constructor accepts; the namespace may carry user-specific extras.
    init_kwargs = {}
    for name in inspect.signature(cls.__init__).parameters:
        if name in available:
            init_kwargs[name] = available[name]

    # Explicit keyword arguments win over anything taken from the namespace.
    init_kwargs.update(**kwargs)
    return cls(**init_kwargs)
def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
    """Parse CLI arguments and normalise boolean options that were left as ``None``.

    Args:
        cls: Class whose signature (as reported by ``get_init_arguments_and_types``) supplies the
            type and default of each argument.
        arg_parser: Either a parser, on which ``parse_args()`` is called, or an already parsed namespace,
            which is used as is.

    Returns:
        A new ``Namespace`` holding every entry of the parsed arguments. An entry whose value is ``None``
        is replaced by ``True`` when the matching argument of ``cls`` accepts ``bool`` and has a ``bool``
        default; all other entries are copied unchanged.
    """
    if isinstance(arg_parser, ArgumentParser):
        namespace = arg_parser.parse_args()
    else:
        namespace = arg_parser

    init_spec = {name: (types, default) for name, types, default in get_init_arguments_and_types(cls)}

    resolved = {}
    for name, value in vars(namespace).items():
        if value is None and name in init_spec:
            types, default = init_spec[name]
            # A ``None`` here is taken to mean the option was given as a bare flag. It is always turned
            # into ``True`` regardless of the default, so disabling an option that defaults to ``True``
            # requires the long form, e.g. ``--some_arg False``.
            if bool in types and isinstance(default, bool):
                value = True
        resolved[name] = value
    return Namespace(**resolved)
def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
    """Parse environment arguments if they are defined.

    For every argument reported by ``get_init_arguments_and_types(cls)`` the environment variable named by
    ``template`` (filled with the upper-cased class name and argument name) is looked up. Unset or empty
    variables are skipped.

    Args:
        cls: Class whose arguments should be looked up in the environment.
        template: ``%``-style pattern with the keys ``cls_name`` and ``cls_argument`` used to build each
            environment variable name.

    Returns:
        A ``Namespace`` with one entry per argument found in the environment. Values that are valid Python
        literals (int, float, bool, ...) are converted with ``ast.literal_eval``; anything else is kept as
        the raw string.

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> parse_env_variables(Trainer)
        Namespace()
        >>> import os
        >>> os.environ["PL_TRAINER_GPUS"] = '42'
        >>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
        >>> parse_env_variables(Trainer)
        Namespace(gpus=42)
        >>> del os.environ["PL_TRAINER_GPUS"]
    """
    env_args = {}
    for arg_name, _, _ in get_init_arguments_and_types(cls):
        env_name = template % {"cls_name": cls.__name__.upper(), "cls_argument": arg_name.upper()}
        val = os.environ.get(env_name)
        if not val:
            # unset (None) or empty string: nothing to take from the environment
            continue
        # Best-effort conversion to native types like int/float/bool. Only the exceptions that
        # ``literal_eval`` is documented to raise on malformed input are suppressed, so unrelated
        # failures are no longer hidden; on failure the raw string is kept.
        with suppress(ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            val = literal_eval(val)
        env_args[arg_name] = val
    return Namespace(**env_args)
def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
    r"""Scan the signature of ``cls`` and describe each of its parameters.

    Args:
        cls: Any object accepted by ``inspect.signature`` (a class, a function, ...).

    Returns:
        List with tuples of 3 values, one per parameter and in signature order:
        (argument name, tuple with argument types, argument default value).
        When the annotation exposes ``__args__`` (e.g. ``Union[int, str]``) the types are those
        arguments; otherwise the tuple holds the annotation itself. Missing annotations and
        defaults are reported as ``inspect.Parameter.empty``.

    Examples:

        >>> from pytorch_lightning import Trainer
        >>> args = get_init_arguments_and_types(Trainer)
    """
    described = []
    for name, param in inspect.signature(cls).parameters.items():
        annotation = param.annotation
        try:
            # Generic annotations carry their member types in ``__args__``.
            types = tuple(annotation.__args__)
        except (AttributeError, TypeError):
            # Plain annotation (or none at all): treat it as the single accepted type.
            types = (annotation,)
        described.append((name, types, param.default))
    return described
def add_argparse_args(
    cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
) -> Union[_ArgumentGroup, ArgumentParser]:
    r"""Extend an existing argparse parser with the constructor arguments of ``cls``.

    Args:
        cls: Lightning class
        parent_parser:
            The custom cli arguments parser, which will be extended by
            the class's default arguments.
        use_argument_group:
            By default, this is True, and the arguments are added to a new group created with
            ``add_argument_group`` on ``parent_parser``.
            If False, a new ``ArgumentParser`` with ``parent_parser`` as its parent is created instead
            (old behavior).

    Returns:
        If use_argument_group is True, returns ``parent_parser`` to keep old workflows. If False, will return the new
        ArgumentParser object.

    Only arguments of the allowed types (str, float, int, bool) will
    extend the ``parent_parser``.

    Raises:
        RuntimeError:
            If ``parent_parser`` is an argument group instead of an ``ArgumentParser`` instance

    Examples:

        >>> # Option 1: Default usage.
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser)
        >>> args = parser.parse_args([])

        >>> # Option 2: Disable use_argument_group (old behavior).
        >>> import argparse
        >>> from pytorch_lightning import Trainer
        >>> parser = argparse.ArgumentParser()
        >>> parser = Trainer.add_argparse_args(parser, use_argument_group=False)
        >>> args = parser.parse_args([])
    """
    if isinstance(parent_parser, _ArgumentGroup):
        raise RuntimeError("Please only pass an ArgumentParser instance.")

    target: Union[_ArgumentGroup, ArgumentParser]
    if use_argument_group:
        target = parent_parser.add_argument_group(_get_abbrev_qualified_cls_name(cls))
    else:
        target = ArgumentParser(parents=[parent_parser], add_help=False)

    skipped_names = ["self", "args", "kwargs"]
    if hasattr(cls, "get_deprecated_arg_names"):
        skipped_names += cls.get_deprecated_arg_names()

    supported_types = (str, int, float, bool)

    # Take the arguments from the class itself, falling back to its ``__init__`` when the class
    # yields nothing usable.
    for symbol in (cls, cls.__init__):
        candidates = [spec for spec in get_init_arguments_and_types(symbol) if spec[0] not in skipped_names]
        if candidates:
            break

    help_by_arg = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

    for name, declared_types, default in candidates:
        # Keep only the supported types, ordered as in ``supported_types``.
        kept_types = tuple(t for t in supported_types if t in declared_types)
        if not kept_types:
            # no supported type -> the argument is not exposed on the command line
            continue

        extra_kwargs: Dict[str, Any] = {}
        converter: Callable[[str], Any]
        if bool not in kept_types:
            converter = kept_types[0]
        else:
            # Boolean options may be given as a bare flag, which then means ``True``.
            extra_kwargs.update(nargs="?", const=True)
            if len(kept_types) == 1:
                # bool is the only accepted type
                converter = str_to_bool
            elif int in kept_types:
                converter = str_to_bool_or_int
            elif str in kept_types:
                converter = str_to_bool_or_str
            else:
                # bool combined with something else: use the first non-bool type
                converter = [t for t in kept_types if t is not bool][0]

        # Special cases. They are applied in this order, so a later one overrides an earlier one.
        if name in ("gpus", "tpu_cores"):
            converter = _gpus_allowed_type
        if len(kept_types) == 2 and int in kept_types and float in kept_types:
            converter = _int_or_float_type
        if name == "track_grad_norm":
            converter = float
        if name == "precision":
            converter = _precision_allowed_type

        target.add_argument(
            f"--{name}", dest=name, default=default, type=converter, help=help_by_arg.get(name), **extra_kwargs
        )

    return parent_parser if use_argument_group else target
def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
    """Extract ``{argument name: description}`` from the arguments section of a docstring.

    The section starts at a line beginning with ``Args:``, ``Arguments:`` or ``Parameters:``. Argument
    entries are expected exactly 4 spaces deeper than that header, in the form ``name: description``.
    Lines indented deeper than that are appended to the description of the current argument.
    The first non-empty line indented less than the entries ends the section.

    Malformed lines no longer raise: an entry-level line without a colon is treated as a continuation
    of the current argument, and continuation lines that appear before any argument are ignored.

    Args:
        docstring: The docstring to scan. May be empty.

    Returns:
        Mapping of argument name to its (single-line) description; empty if no section is found.
    """
    arg_block_indent = None
    current_arg = None
    parsed: Dict[str, str] = {}
    for line in docstring.split("\n"):
        stripped = line.lstrip()
        if not stripped:
            continue
        line_indent = len(line) - len(stripped)
        if stripped.startswith(("Args:", "Arguments:", "Parameters:")):
            arg_block_indent = line_indent + 4
        elif arg_block_indent is None:
            # still before the arguments section
            continue
        elif line_indent < arg_block_indent:
            # dedent: the arguments section is over
            break
        else:
            name, colon, description = stripped.partition(":")
            if line_indent == arg_block_indent and colon:
                current_arg = name
                parsed[current_arg] = description.lstrip()
            elif current_arg is not None:
                # wrapped description text belonging to the current argument
                parsed[current_arg] += f" {stripped}"
    return parsed


def _gpus_allowed_type(x: str) -> Union[int, str]:
    """Convert a ``gpus``-style value: keep comma separated lists as strings, otherwise parse an int.

    Raises:
        ValueError: If ``x`` contains no comma and is not a valid integer.
    """
    if "," in x:
        return str(x)
    return int(x)


def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
    """Convert ``x`` to an ``int`` when it denotes one, otherwise to a ``float``.

    Values written with a decimal point become floats, plain integers become ints. Values that are not
    valid integers but are valid floats, such as scientific notation without a decimal point
    (``"1e-3"``), are returned as floats instead of being rejected. Float inputs are never truncated.

    >>> _int_or_float_type("3")
    3
    >>> _int_or_float_type("0.5")
    0.5
    >>> _int_or_float_type("1e-3")
    0.001

    Raises:
        ValueError: If ``x`` is neither a valid integer nor a valid float.
    """
    if isinstance(x, float) or "." in str(x):
        return float(x)
    try:
        return int(x)
    except ValueError:
        # e.g. "1e-3": not an integer literal, but still a perfectly valid float
        return float(x)


def _precision_allowed_type(x: Union[int, str]) -> Union[int, str]:
    """Convert a precision value to ``int`` when possible, otherwise return it unchanged.

    >>> _precision_allowed_type("32")
    32
    >>> _precision_allowed_type("bf16")
    'bf16'
    """
    try:
        return int(x)
    except ValueError:
        return x


def _defaults_from_env_vars(fn: _T) -> _T:
    """Decorate an ``__init__``-like method so environment variables provide argument defaults.

    The wrapper converts positional arguments to keyword arguments (using the argument names reported
    by ``get_init_arguments_and_types`` for the instance's class), merges in the values returned by
    ``parse_env_variables`` for that class, and calls ``fn`` with keyword arguments only.
    Explicitly passed arguments take precedence over environment values.
    """

    @wraps(fn)
    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
        cls = self.__class__
        if args:
            # Positional arguments are matched to argument names by position and moved to kwargs.
            cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
            kwargs.update(dict(zip(cls_arg_names, args)))
        env_variables = vars(parse_env_variables(cls))
        # Environment values first, so that anything passed by the caller overrides them.
        kwargs = {**env_variables, **kwargs}

        # all args were already moved to kwargs
        return fn(self, **kwargs)

    return cast(_T, insert_env_defaults)
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.