# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for LightningCLI."""

import inspect
import os
import sys
from functools import partial, update_wrapper
from types import MethodType, ModuleType
from typing import Any, Callable, Dict, Generator, List, Optional, Set, Tuple, Type, Union
from unittest import mock

import torch
import yaml
from torch.optim import Optimizer

import pytorch_lightning as pl
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
from pytorch_lightning.utilities.meta import get_all_subclasses
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_warn
from pytorch_lightning.utilities.types import LRSchedulerType, LRSchedulerTypeTuple, LRSchedulerTypeUnion

if _JSONARGPARSE_AVAILABLE:
    from jsonargparse import ActionConfigFile, ArgumentParser, class_from_function, Namespace, set_config_read_mode
    from jsonargparse.optionals import import_docstring_parse

    set_config_read_mode(fsspec_enabled=True)
else:
    locals()["ArgumentParser"] = object
    locals()["Namespace"] = object


class _Registry(dict):
    def __call__(self, cls: Type, key: Optional[str] = None, override: bool = False) -> Type:
        """Registers a class mapped to a name.

        Args:
            cls: the class to be mapped.
            key: the name that identifies the provided class.
            override: Whether to override an existing key.
        """
        if key is None:
            key = cls.__name__
        elif not isinstance(key, str):
            raise TypeError(f"`key` must be a str, found {key}")

        if key not in self or override:
            self[key] = cls
        return cls

    def register_classes(self, module: ModuleType, base_cls: Type, override: bool = False) -> None:
        """This function is a utility to register all classes from a module."""
        for cls in self.get_members(module, base_cls):
            self(cls=cls, override=override)

    @staticmethod
    def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
        return (
            cls
            for _, cls in inspect.getmembers(module, predicate=inspect.isclass)
            if issubclass(cls, base_cls) and cls != base_cls
        )

    @property
    def names(self) -> List[str]:
        """Returns the registered names."""
        return list(self.keys())

    @property
    def classes(self) -> Tuple[Type, ...]:
        """Returns the registered classes."""
        return tuple(self.values())

    def __str__(self) -> str:
        return f"Registered objects: {self.names}"


OPTIMIZER_REGISTRY = _Registry()
LR_SCHEDULER_REGISTRY = _Registry()
CALLBACK_REGISTRY = _Registry()
MODEL_REGISTRY = _Registry()
DATAMODULE_REGISTRY = _Registry()
LOGGER_REGISTRY = _Registry()

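# Illustrative sketch (not part of the upstream file): since `_Registry.__call__` returns
# the class, a registry can also be used as a decorator. `MyCallback` is hypothetical.
#
#   @CALLBACK_REGISTRY
#   class MyCallback(pl.Callback):
#       ...
#
# After registration, "MyCallback" appears in `CALLBACK_REGISTRY.names` and the class can
# be selected on the command line with the shorthand `--trainer.callbacks=MyCallback`.
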
# `ReduceLROnPlateau` wrapper with an explicit `monitor`, referenced below in
# `LightningCLI.configure_optimizers`; its definition was dropped in extraction
class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
        super().__init__(optimizer, *args, **kwargs)
        self.monitor = monitor


def _populate_registries(subclasses: bool) -> None:
    if subclasses:
        # this will register any subclasses from all loaded modules including userland
        for cls in get_all_subclasses(torch.optim.Optimizer):
            OPTIMIZER_REGISTRY(cls)
        for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
            LR_SCHEDULER_REGISTRY(cls)
        for cls in get_all_subclasses(pl.Callback):
            CALLBACK_REGISTRY(cls)
        for cls in get_all_subclasses(pl.LightningModule):
            MODEL_REGISTRY(cls)
        for cls in get_all_subclasses(pl.LightningDataModule):
            DATAMODULE_REGISTRY(cls)
        for cls in get_all_subclasses(pl.loggers.LightningLoggerBase):
            LOGGER_REGISTRY(cls)
    else:
        # manually register torch's subclasses and our subclasses
        OPTIMIZER_REGISTRY.register_classes(torch.optim, Optimizer)
        LR_SCHEDULER_REGISTRY.register_classes(torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler)
        CALLBACK_REGISTRY.register_classes(pl.callbacks, pl.Callback)
        LOGGER_REGISTRY.register_classes(pl.loggers, pl.loggers.LightningLoggerBase)
    # `ReduceLROnPlateau` does not subclass `_LRScheduler`
    LR_SCHEDULER_REGISTRY(cls=ReduceLROnPlateau)

class LightningArgumentParser(ArgumentParser):
    """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

    # use class attribute because `parse_args` is only called on the main parser
    _choices: Dict[str, Tuple[Tuple[Type, ...], bool]] = {}

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize argument parser that supports configuration file input.

        For full details of accepted arguments see `ArgumentParser.__init__
        <https://jsonargparse.readthedocs.io/en/stable/index.html#jsonargparse.ArgumentParser.__init__>`_.
        """
        if not _JSONARGPARSE_AVAILABLE:
            raise ModuleNotFoundError(
                "`jsonargparse` is not installed but it is required for the CLI."
                " Install it with `pip install -U jsonargparse[signatures]`."
            )
        super().__init__(*args, **kwargs)
        self.add_argument(
            "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
        )
        self.callback_keys: List[str] = []
        # separate optimizers and lr schedulers to know which were added
        self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
        self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}

    def add_lightning_class_args(
        self,
        lightning_class: Union[
            Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
            Type[Trainer],
            Type[LightningModule],
            Type[LightningDataModule],
            Type[Callback],
        ],
        nested_key: str,
        subclass_mode: bool = False,
        required: bool = True,
    ) -> List[str]:
        """Adds arguments from a lightning class to a nested key of the parser.

        Args:
            lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
            nested_key: Name of the nested namespace to store arguments.
            subclass_mode: Whether to allow any subclass of the given class.
            required: Whether the argument group is required.

        Returns:
            A list with the names of the class arguments added.
        """
        if callable(lightning_class) and not isinstance(lightning_class, type):
            lightning_class = class_from_function(lightning_class)

        if isinstance(lightning_class, type) and issubclass(
            lightning_class, (Trainer, LightningModule, LightningDataModule, Callback)
        ):
            if issubclass(lightning_class, Callback):
                self.callback_keys.append(nested_key)
            if subclass_mode:
                return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
            return self.add_class_arguments(
                lightning_class,
                nested_key,
                fail_untyped=False,
                instantiate=not issubclass(lightning_class, Trainer),
                sub_configs=True,
            )
        raise MisconfigurationException(
            f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
            "Trainer, LightningModule, LightningDataModule, or Callback."
        )

    def add_optimizer_args(
        self,
        optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]],
        nested_key: str = "optimizer",
        link_to: str = "AUTOMATIC",
    ) -> None:
        """Adds arguments from an optimizer class to a nested key of the parser.

        Args:
            optimizer_class: Any subclass of :class:`torch.optim.Optimizer`.
            nested_key: Name of the nested namespace to store arguments.
            link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
        """
        if isinstance(optimizer_class, tuple):
            assert all(issubclass(o, Optimizer) for o in optimizer_class)
        else:
            assert issubclass(optimizer_class, Optimizer)
        kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
        if isinstance(optimizer_class, tuple):
            self.add_subclass_arguments(optimizer_class, nested_key, **kwargs)
            self.set_choices(nested_key, optimizer_class)
        else:
            self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **kwargs)
        self._optimizers[nested_key] = (optimizer_class, link_to)

    def add_lr_scheduler_args(
        self,
        lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]],
        nested_key: str = "lr_scheduler",
        link_to: str = "AUTOMATIC",
    ) -> None:
        """Adds arguments from a learning rate scheduler class to a nested key of the parser.

        Args:
            lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``.
            nested_key: Name of the nested namespace to store arguments.
            link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
        """
        if isinstance(lr_scheduler_class, tuple):
            assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
        else:
            assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
        kwargs = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
        if isinstance(lr_scheduler_class, tuple):
            self.add_subclass_arguments(lr_scheduler_class, nested_key, **kwargs)
            self.set_choices(nested_key, lr_scheduler_class)
        else:
            self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **kwargs)
        self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)

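    # Illustrative sketch (not part of the upstream file): a standalone parser with a tuple
    # of optimizer classes enables subclass mode plus shorthand choices, e.g.
    #
    #   parser = LightningArgumentParser()
    #   parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
    #   parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
    #
    # which can then be driven with `--optimizer=Adam --optimizer.lr=0.01 --lr_scheduler.gamma=0.9`.
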
    def parse_args(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
        argv = sys.argv
        for k, v in self._choices.items():
            if not any(arg.startswith(f"--{k}") for arg in argv):
                # the key wasn't passed - maybe defined in a config, maybe it's optional
                continue
            classes, is_list = v
            # knowing whether the argument is a list type automatically would be too complex
            if is_list:
                argv = self._convert_argv_issue_85(classes, k, argv)
            else:
                argv = self._convert_argv_issue_84(classes, k, argv)
        self._choices.clear()
        with mock.patch("sys.argv", argv):
            return super().parse_args(*args, **kwargs)

    def set_choices(self, nested_key: str, classes: Tuple[Type, ...], is_list: bool = False) -> None:
        """Adds support for shorthand notation for a particular nested key.

        Args:
            nested_key: The key whose choices will be set.
            classes: A tuple of classes to choose from.
            is_list: Whether the argument is a ``List[object]`` type.
        """
        self._choices[nested_key] = (classes, is_list)

    @staticmethod
    def _convert_argv_issue_84(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
        """Placeholder for https://github.com/omni-us/jsonargparse/issues/84.

        Adds support for shorthand notation for ``object`` arguments.
        """
        passed_args, clean_argv = {}, []
        argv_key = f"--{nested_key}"
        # get the argv args for this nested key
        i = 0
        while i < len(argv):
            arg = argv[i]
            if arg.startswith(argv_key):
                if "=" in arg:
                    key, value = arg.split("=")
                else:
                    key = arg
                    i += 1
                    value = argv[i]
                passed_args[key] = value
            else:
                clean_argv.append(arg)
            i += 1

        # the user requested a help message
        help_key = argv_key + ".help"
        if help_key in passed_args:
            argv_class = passed_args[help_key]
            if "." in argv_class:
                # user passed the class path directly
                class_path = argv_class
            else:
                # convert shorthand format to the classpath
                for cls in classes:
                    if cls.__name__ == argv_class:
                        class_path = _class_path_from_class(cls)
                        break
                else:
                    raise ValueError(f"Could not get the class_path for {repr(argv_class)}")
            return clean_argv + [help_key, class_path]

        # generate the associated config file
        argv_class = passed_args.pop(argv_key, "")
        if not argv_class:
            # the user passed a config as a str
            class_path = passed_args[f"{argv_key}.class_path"]
            init_args_key = f"{argv_key}.init_args"
            init_args = {
                k[len(init_args_key) + 1 :]: v for k, v in passed_args.items() if k.startswith(init_args_key)
            }
            config = str({"class_path": class_path, "init_args": init_args})
        elif argv_class.startswith("{") or argv_class in ("None", "True", "False"):
            # the user passed a config as a dict
            config = argv_class
        else:
            # the user passed the shorthand format
            init_args = {k[len(argv_key) + 1 :]: v for k, v in passed_args.items()}  # +1 to account for the period
            for cls in classes:
                if cls.__name__ == argv_class:
                    config = str(_global_add_class_path(cls, init_args))
                    break
            else:
                raise ValueError(f"Could not generate a config for {repr(argv_class)}")
        return clean_argv + [argv_key, config]

    @staticmethod
    def _convert_argv_issue_85(classes: Tuple[Type, ...], nested_key: str, argv: List[str]) -> List[str]:
        """Placeholder for https://github.com/omni-us/jsonargparse/issues/85.

        Adds support for shorthand notation for ``List[object]`` arguments.
        """
        passed_args, clean_argv = [], []
        passed_configs = {}
        argv_key = f"--{nested_key}"
        # get the argv args for this nested key
        i = 0
        while i < len(argv):
            arg = argv[i]
            if arg.startswith(argv_key):
                if "=" in arg:
                    key, value = arg.split("=")
                else:
                    key = arg
                    i += 1
                    value = argv[i]
                if "class_path" in value:
                    # the user passed a config as a dict
                    passed_configs[key] = yaml.safe_load(value)
                else:
                    passed_args.append((key, value))
            else:
                clean_argv.append(arg)
            i += 1
        # generate the associated config file
        config = []
        i, n = 0, len(passed_args)
        while i < n - 1:
            ki, vi = passed_args[i]
            # convert class name to class path
            for cls in classes:
                if cls.__name__ == vi:
                    cls_type = cls
                    break
            else:
                raise ValueError(f"Could not generate a config for {repr(vi)}")
            config.append(_global_add_class_path(cls_type))
            # get any init args
            j = i + 1  # in case the j-loop doesn't run
            for j in range(i + 1, n):
                kj, vj = passed_args[j]
                if ki == kj:
                    break
                if kj.startswith(ki):
                    init_arg_name = kj.split(".")[-1]
                    config[-1]["init_args"][init_arg_name] = vj
            i = j
        # update at the end to preserve the order
        for k, v in passed_configs.items():
            config.extend(v)
        if not config:
            return clean_argv
        return clean_argv + [argv_key, str(config)]

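    # Illustrative sketch (not part of the upstream file) of what `_convert_argv_issue_84`
    # produces: given `classes=(torch.optim.Adam,)` and `nested_key="optimizer"`, the argv
    #
    #   ["--optimizer=Adam", "--optimizer.lr=0.01"]
    #
    # is rewritten into the full class-path form understood by jsonargparse:
    #
    #   ["--optimizer", "{'class_path': 'torch.optim.Adam', 'init_args': {'lr': '0.01'}}"]
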

class SaveConfigCallback(Callback):
    """Saves a LightningCLI config to the log_dir when training starts.

    Args:
        parser: The parser object used to parse the configuration.
        config: The parsed configuration that will be saved.
        config_filename: Filename for the config file.
        overwrite: Whether to overwrite an existing config file.
        multifile: When input is multiple config files, saved config preserves this structure.

    Raises:
        RuntimeError: If the config file already exists in the directory, to avoid overwriting a previous run.
    """

    def __init__(
        self,
        parser: LightningArgumentParser,
        config: Namespace,
        config_filename: str,
        overwrite: bool = False,
        multifile: bool = False,
    ) -> None:
        self.parser = parser
        self.config = config
        self.config_filename = config_filename
        self.overwrite = overwrite
        self.multifile = multifile

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
        log_dir = trainer.log_dir  # this broadcasts the directory
        assert log_dir is not None
        config_path = os.path.join(log_dir, self.config_filename)
        fs = get_filesystem(log_dir)

        if not self.overwrite:
            # check if the file exists on rank 0
            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
            # broadcast whether to fail to all ranks
            file_exists = trainer.strategy.broadcast(file_exists)
            if file_exists:
                raise RuntimeError(
                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                    " results of a previous run. You can delete the previous config file,"
                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                    " or set `LightningCLI(save_config_overwrite=True)` to overwrite the config file."
                )

        # save the file on rank 0
        if trainer.is_global_zero:
            # save only on rank zero to avoid race conditions.
            # the `log_dir` needs to be created as we rely on the logger to do it usually
            # but it hasn't logged anything at this point
            fs.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
            )


class LightningCLI:
    """Implementation of a configurable command line tool for pytorch-lightning."""

    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_filename: str = "config.yaml",
        save_config_overwrite: bool = False,
        save_config_multifile: bool = False,
        trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
        trainer_defaults: Optional[Dict[str, Any]] = None,
        seed_everything_default: Optional[int] = None,
        description: str = "pytorch-lightning trainer command line tool",
        env_prefix: str = "PL",
        env_parse: bool = False,
        parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False,
        run: bool = True,
        auto_registry: bool = False,
    ) -> None:
        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
        are called / instantiated using a parsed configuration file and / or command line args.

        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``. A full
        configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are also parsed from
        variables named, for example, ``PL_TRAINER__MAX_EPOCHS``.

        For more info, read :ref:`the CLI docs <common/lightning_cli:LightningCLI>`.

        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.lightning.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.lightning.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
            save_config_callback: A callback class to save the training config.
            save_config_filename: Filename for the config file.
            save_config_overwrite: Whether to overwrite an existing config file.
            save_config_multifile: When input is multiple config files, saved config preserves this structure.
            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
                callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added through
                this argument will not be configurable from a configuration file and will always be present for this
                particular CLI. Alternatively, configurable callbacks can be added as explained in
                :ref:`the CLI docs <common/lightning_cli:Configurable callbacks>`.
            seed_everything_default: Default value for the :func:`~pytorch_lightning.utilities.seed.seed_everything`
                seed argument.
            description: Description of the tool shown when running ``--help``.
            env_prefix: Prefix for environment variables.
            env_parse: Whether environment variable parsing is enabled.
            parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
        """
        self.save_config_callback = save_config_callback
        self.save_config_filename = save_config_filename
        self.save_config_overwrite = save_config_overwrite
        self.save_config_multifile = save_config_multifile
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class or LightningModule
        self.subclass_mode_model = (model_class is None) or subclass_mode_model

        self.datamodule_class = datamodule_class
        # used to differentiate between the original value and the processed value
        self._datamodule_class = datamodule_class or LightningDataModule
        self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

        _populate_registries(auto_registry)

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
        )
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser)

        self.subcommand = self.config["subcommand"] if run else None

        seed = self._get(self.config, "seed_everything")
        if seed is not None:
            seed_everything(seed, workers=True)

        self.before_instantiate_classes()
        self.instantiate_classes()

        if self.subcommand is not None:
            self._run_subcommand(self.subcommand)

    def _setup_parser_kwargs(
        self, kwargs: Dict[str, Any], defaults: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        if kwargs.keys() & self.subcommands().keys():
            # `kwargs` contains arguments per subcommand
            return defaults, kwargs
        main_kwargs = defaults
        main_kwargs.update(kwargs)
        return main_kwargs, {}

    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
        """Method that instantiates the argument parser."""
        return LightningArgumentParser(**kwargs)

    def setup_parser(
        self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any]
    ) -> None:
        """Initialize and set up the parser, subcommands, and arguments."""
        self.parser = self.init_parser(**main_kwargs)
        if add_subcommands:
            self._subcommand_method_arguments: Dict[str, List[str]] = {}
            self._add_subcommands(self.parser, **subparser_kwargs)
        else:
            self._add_arguments(self.parser)

    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Adds default arguments to the parser."""
        parser.add_argument(
            "--seed_everything",
            type=Optional[int],
            default=self.seed_everything_default,
            help="Set to an int to run seed_everything with this value before classes instantiation",
        )

    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Adds arguments from the core classes to the parser."""
        parser.add_lightning_class_args(self.trainer_class, "trainer")
        parser.set_choices("trainer.callbacks", CALLBACK_REGISTRY.classes, is_list=True)
        parser.set_choices("trainer.logger", LOGGER_REGISTRY.classes)
        trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
        parser.set_defaults(trainer_defaults)

        parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
        if self.model_class is None and len(MODEL_REGISTRY):
            # did not pass a model and there are models registered
            parser.set_choices("model", MODEL_REGISTRY.classes)

        if self.datamodule_class is not None:
            parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
        elif len(DATAMODULE_REGISTRY):
            # this should not be required because the user might want to use the `LightningModule` dataloaders
            parser.add_lightning_class_args(
                self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
            )
            parser.set_choices("data", DATAMODULE_REGISTRY.classes)

    def _add_arguments(self, parser: LightningArgumentParser) -> None:
        # default + core + custom arguments
        self.add_default_arguments_to_parser(parser)
        self.add_core_arguments_to_parser(parser)
        self.add_arguments_to_parser(parser)
        # add default optimizer args if necessary
        if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
            parser.add_optimizer_args(OPTIMIZER_REGISTRY.classes)
        if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
            parser.add_lr_scheduler_args(LR_SCHEDULER_REGISTRY.classes)
        self.link_optimizers_and_lr_schedulers(parser)

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Implement to add extra arguments to the parser or link arguments.

        Args:
            parser: The parser object to which arguments can be added
        """

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        """Defines the list of available subcommands and the arguments to skip."""
        return {
            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
            "validate": {"model", "dataloaders", "datamodule"},
            "test": {"model", "dataloaders", "datamodule"},
            "predict": {"model", "dataloaders", "datamodule"},
            "tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
        }

    def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
        """Adds subcommands to the input parser."""
        parser_subcommands = parser.add_subcommands()
        # the user might have passed a builder function
        trainer_class = (
            self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class)
        )
        # register all subcommands in separate subcommand parsers under the main parser
        for subcommand in self.subcommands():
            subcommand_parser = self._prepare_subcommand_parser(
                trainer_class, subcommand, **kwargs.get(subcommand, {})
            )
            fn = getattr(trainer_class, subcommand)
            # extract the first line description in the docstring for the subcommand help message
            description = _get_short_description(fn)
            parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description)

    def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser:
        parser = self.init_parser(**kwargs)
        self._add_arguments(parser)
        # subcommand arguments
        skip = self.subcommands()[subcommand]
        added = parser.add_method_arguments(klass, subcommand, skip=skip)
        # need to save which arguments were added to pass them to the method later
        self._subcommand_method_arguments[subcommand] = added
        return parser

    @staticmethod
    def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
        """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``."""
        optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
        for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
            if link_to == "AUTOMATIC":
                continue
            if isinstance(class_type, tuple):
                parser.link_arguments(key, link_to)
            else:
                add_class_path = _add_class_path_generator(class_type)
                parser.link_arguments(key, link_to, compute_fn=add_class_path)

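    # Illustrative sketch (not part of the upstream file): `link_to` routes the parsed
    # optimizer config into a model constructor argument instead of the 'AUTOMATIC' handling.
    # `optimizer_init` is a hypothetical `dict` parameter of the user's LightningModule.
    #
    #   parser.add_optimizer_args(torch.optim.Adam, link_to="model.optimizer_init")
    #
    # The model is then responsible for instantiating it, e.g. inside `configure_optimizers`
    # with `instantiate_class(self.parameters(), self.optimizer_init)`.
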
    def parse_arguments(self, parser: LightningArgumentParser) -> None:
        """Parses command line arguments and stores the result in ``self.config``."""
        self.config = parser.parse_args()

    def before_instantiate_classes(self) -> None:
        """Implement to run some code before instantiating the classes."""

    def instantiate_classes(self) -> None:
        """Instantiates the classes and sets their attributes."""
        self.config_init = self.parser.instantiate_classes(self.config)
        self.datamodule = self._get(self.config_init, "data")
        self.model = self._get(self.config_init, "model")
        self._add_configure_optimizers_method_to_model(self.subcommand)
        self.trainer = self.instantiate_trainer()

    def instantiate_trainer(self, **kwargs: Any) -> Trainer:
        """Instantiates the trainer.

        Args:
            kwargs: Any custom trainer arguments.
        """
        extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
        trainer_config = {**self._get(self.config_init, "trainer"), **kwargs}
        return self._instantiate_trainer(trainer_config, extra_callbacks)

    def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
        config["callbacks"] = config["callbacks"] or []
        config["callbacks"].extend(callbacks)
        if "callbacks" in self.trainer_defaults:
            if isinstance(self.trainer_defaults["callbacks"], list):
                config["callbacks"].extend(self.trainer_defaults["callbacks"])
            else:
                config["callbacks"].append(self.trainer_defaults["callbacks"])
        if self.save_config_callback and not config["fast_dev_run"]:
            config_callback = self.save_config_callback(
                self._parser(self.subcommand),
                self.config.get(str(self.subcommand), self.config),
                self.save_config_filename,
                overwrite=self.save_config_overwrite,
                multifile=self.save_config_multifile,
            )
            config["callbacks"].append(config_callback)
        return self.trainer_class(**config)

    def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
        if subcommand is None:
            return self.parser
        # return the subcommand parser for the subcommand passed
        action_subcommand = self.parser._subcommands_action
        return action_subcommand._name_parser_map[subcommand]

    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
    ) -> Any:
        """Override to customize the :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`
        method.

        Args:
            lightning_module: A reference to the model.
            optimizer: The optimizer.
            lr_scheduler: The learning rate scheduler (if used).
        """
        if lr_scheduler is None:
            return optimizer
        if isinstance(lr_scheduler, ReduceLROnPlateau):
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
            }
        return [optimizer], [lr_scheduler]

    def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
        """Overrides the model's :meth:`~pytorch_lightning.core.lightning.LightningModule.configure_optimizers`
        method if a single optimizer and optionally a scheduler argument group were added to the parser as
        'AUTOMATIC'."""
        parser = self._parser(subcommand)

        def get_automatic(
            class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
        ) -> List[str]:
            automatic = []
            for key, (base_class, link_to) in register.items():
                if not isinstance(base_class, tuple):
                    base_class = (base_class,)
                if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
                    automatic.append(key)
            return automatic

        optimizers = get_automatic(Optimizer, parser._optimizers)
        lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

        if len(optimizers) == 0:
            return

        if len(optimizers) > 1 or len(lr_schedulers) > 1:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers + lr_schedulers}. In this case the "
                "user is expected to link the argument groups and implement `configure_optimizers`, see "
                "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
                "#optimizers-and-learning-rate-schedulers"
            )

        optimizer_class = parser._optimizers[optimizers[0]][0]
        optimizer_init = self._get(self.config_init, optimizers[0])
        if not isinstance(optimizer_class, tuple):
            optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
        if not optimizer_init:
            # optimizers were registered automatically but not passed by the user
            return

        lr_scheduler_init = None
        if lr_schedulers:
            lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
            lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
            if not isinstance(lr_scheduler_class, tuple):
                lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        if is_overridden("configure_optimizers", self.model):
            _warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.configure_optimizers`."
            )

        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)

    def _get(self, config: Dict[str, Any], key: str, default: Optional[Any] = None) -> Any:
        """Utility to get a config value which might be inside a subcommand."""
        return config.get(str(self.subcommand), config).get(key, default)

    def _run_subcommand(self, subcommand: str) -> None:
        """Run the chosen subcommand."""
        before_fn = getattr(self, f"before_{subcommand}", None)
        if callable(before_fn):
            before_fn()

        default = getattr(self.trainer, subcommand)
        fn = getattr(self, subcommand, default)
        fn_kwargs = self._prepare_subcommand_kwargs(subcommand)
        fn(**fn_kwargs)

        after_fn = getattr(self, f"after_{subcommand}", None)
        if callable(after_fn):
            after_fn()

    def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
        """Prepares the keyword arguments to pass to the subcommand to run."""
        fn_kwargs = {
            k: v
            for k, v in self.config_init[subcommand].items()
            if k in self._subcommand_method_arguments[subcommand]
        }
        fn_kwargs["model"] = self.model
        if self.datamodule is not None:
            fn_kwargs["datamodule"] = self.datamodule
        return fn_kwargs


# class-path helpers referenced above (`_class_path_from_class`, `_global_add_class_path`,
# `_add_class_path_generator`); their definitions were dropped in extraction
def _class_path_from_class(class_type: Type) -> str:
    return class_type.__module__ + "." + class_type.__name__


def _global_add_class_path(
    class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
) -> Dict[str, Any]:
    if isinstance(init_args, Namespace):
        init_args = init_args.as_dict()
    return {"class_path": _class_path_from_class(class_type), "init_args": init_args or {}}


def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
    def add_class_path(init_args: Namespace) -> Dict[str, Any]:
        return _global_add_class_path(class_type, init_args)

    return add_class_path

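# Illustrative sketch (not part of the upstream file): the typical entry point wires a
# LightningModule (and optionally a LightningDataModule) into the CLI. `MyModel` and
# `MyDataModule` are hypothetical user classes.
#
#   from my_project import MyModel, MyDataModule  # hypothetical imports
#
#   cli = LightningCLI(MyModel, MyDataModule)
#
# invoked from the shell as, e.g., `python script.py fit --trainer.max_epochs=3`.
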
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation.
        init: Dict of the form {"class_path":...,"init_args":...}.

    Returns:
        The instantiated class object.
    """
    kwargs = init.get("init_args", {})
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)

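# Illustrative usage (not part of the upstream file): build an optimizer from a class-path
# config, mirroring how `_add_configure_optimizers_method_to_model` calls this function.
# `model` stands for any `torch.nn.Module` instance.
#
#   init = {"class_path": "torch.optim.SGD", "init_args": {"lr": 0.01}}
#   optimizer = instantiate_class(model.parameters(), init)
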
def _get_short_description(component: object) -> Optional[str]:
    parse, _ = import_docstring_parse("LightningCLI(run=True)")
    try:
        docstring = parse(component.__doc__)
        return docstring.short_description
    except ValueError:
        rank_zero_warn(f"Failed parsing docstring for {component}")