Shortcuts

Source code for pytorch_lightning.cli

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from functools import partial, update_wrapper
from types import MethodType
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union

import torch
from lightning_utilities.core.imports import RequirementCache
from lightning_utilities.core.rank_zero import _warn
from torch.optim import Optimizer

import pytorch_lightning as pl
from lightning_lite.utilities.cloud_io import get_filesystem
from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn

# Availability check for the optional dependency: jsonargparse with its "signatures" extra, version >= 4.15.2.
# Its string form is interpolated into the ModuleNotFoundError raised by `LightningArgumentParser.__init__`.
_JSONARGPARSE_SIGNATURES_AVAILABLE = RequirementCache("jsonargparse[signatures]>=4.15.2")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
    import docstring_parser
    from jsonargparse import (
        ActionConfigFile,
        ArgumentParser,
        class_from_function,
        Namespace,
        register_unresolvable_import_paths,
        set_config_read_mode,
    )

    register_unresolvable_import_paths(torch)  # Required until fix https://github.com/pytorch/pytorch/issues/74483
    # Read config files through fsspec (presumably so remote/URL config paths work — confirm in jsonargparse docs).
    set_config_read_mode(fsspec_enabled=True)
else:
    # Placeholders so this module still imports without jsonargparse: `LightningArgumentParser` subclasses
    # `ArgumentParser`, and `ArgsType` below references `Namespace`. Constructing a `LightningArgumentParser`
    # then raises ModuleNotFoundError. At module level `locals()` is the module namespace, so these behave like
    # global assignments. NOTE(review): the indirection presumably hides the redefinition from static type
    # checkers — confirm.
    locals()["ArgumentParser"] = object
    locals()["Namespace"] = object


# Accepted forms of `LightningCLI(args=...)`: command-line style strings, a dict, a jsonargparse Namespace,
# or None (arguments are then taken from `sys.argv`).
ArgsType = Optional[Union[List[str], Dict[str, Any], Namespace]]


class ReduceLROnPlateau(torch.optim.lr_scheduler.ReduceLROnPlateau):
    """Torch's ``ReduceLROnPlateau`` that additionally remembers which metric it should be stepped on.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        monitor: Metric name; ``LightningCLI.configure_optimizers`` returns it as the ``"monitor"`` entry of the
            scheduler configuration.
        *args: Forwarded positionally to the torch scheduler, after ``optimizer``.
        **kwargs: Forwarded as keyword arguments to the torch scheduler.
    """

    def __init__(self, optimizer: Optimizer, monitor: str, *args: Any, **kwargs: Any) -> None:
        # the metric name is independent of the torch scheduler state, so record it up front
        self.monitor = monitor
        super().__init__(optimizer, *args, **kwargs)


# LightningCLI requires the ReduceLROnPlateau defined here, thus it shouldn't accept the one from pytorch.
# Three spellings of the same set of accepted scheduler types:
# - the tuple is used for runtime `issubclass` checks (`LightningArgumentParser.add_lr_scheduler_args` and the
#   AUTOMATIC detection in `LightningCLI._add_configure_optimizers_method_to_model`);
# - the `Union` of instances annotates a scheduler object (`LightningCLI.configure_optimizers`);
# - the `Union` of `Type[...]` annotates a scheduler class (`LightningArgumentParser.add_lr_scheduler_args`).
# NOTE(review): newer torch releases may make torch's own ReduceLROnPlateau a subclass of the scheduler base
# class, in which case the tuple check would accept it too — confirm against the supported torch versions.
LRSchedulerTypeTuple = (torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau)
LRSchedulerTypeUnion = Union[torch.optim.lr_scheduler._LRScheduler, ReduceLROnPlateau]
LRSchedulerType = Union[Type[torch.optim.lr_scheduler._LRScheduler], Type[ReduceLROnPlateau]]


class LightningArgumentParser(ArgumentParser):
    """Extension of jsonargparse's ArgumentParser for pytorch-lightning."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        """Initialize argument parser that supports configuration file input.

        For full details of accepted arguments see `ArgumentParser.__init__
        <https://jsonargparse.readthedocs.io/en/stable/index.html#jsonargparse.ArgumentParser.__init__>`_.

        Raises:
            ModuleNotFoundError: If ``jsonargparse[signatures]`` is not installed at a supported version.
        """
        if not _JSONARGPARSE_SIGNATURES_AVAILABLE:
            install_hint = "Try `pip install -U 'jsonargparse[signatures]'`."
            raise ModuleNotFoundError(f"{_JSONARGPARSE_SIGNATURES_AVAILABLE}. {install_hint}")
        super().__init__(*args, **kwargs)
        # nested keys of the argument groups that hold callbacks (appended by `add_lightning_class_args`)
        self.callback_keys: List[str] = []
        # Registries mapping a nested key to ``(class or tuple of classes, link_to)``. Optimizers and lr schedulers
        # are kept apart so it is possible to tell which of each kind were added.
        self._optimizers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}
        self._lr_schedulers: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]] = {}

    def add_lightning_class_args(
        self,
        lightning_class: Union[
            Callable[..., Union[Trainer, LightningModule, LightningDataModule, Callback]],
            Type[Trainer],
            Type[LightningModule],
            Type[LightningDataModule],
            Type[Callback],
        ],
        nested_key: str,
        subclass_mode: bool = False,
        required: bool = True,
    ) -> List[str]:
        """Register the constructor arguments of a Lightning class under ``nested_key``.

        Args:
            lightning_class: A callable or any subclass of {Trainer, LightningModule, LightningDataModule, Callback}.
            nested_key: Name of the nested namespace to store arguments.
            subclass_mode: Whether allow any subclass of the given class.
            required: Whether the argument group is required.

        Returns:
            A list with the names of the class arguments added.

        Raises:
            MisconfigurationException: If ``lightning_class`` (after converting a plain callable) is not one of the
                supported Lightning types.
        """
        # plain callables (builder functions) are turned into a class jsonargparse can introspect
        if callable(lightning_class) and not isinstance(lightning_class, type):
            lightning_class = class_from_function(lightning_class)

        supported_bases = (Trainer, LightningModule, LightningDataModule, Callback)
        is_supported = isinstance(lightning_class, type) and issubclass(lightning_class, supported_bases)
        if not is_supported:
            raise MisconfigurationException(
                f"Cannot add arguments from: {lightning_class}. You should provide either a callable or a subclass of: "
                "Trainer, LightningModule, LightningDataModule, or Callback."
            )

        if issubclass(lightning_class, Callback):
            self.callback_keys.append(nested_key)
        if subclass_mode:
            return self.add_subclass_arguments(lightning_class, nested_key, fail_untyped=False, required=required)
        # the Trainer group is kept as plain arguments; every other class is instantiated by the parser
        return self.add_class_arguments(
            lightning_class,
            nested_key,
            fail_untyped=False,
            instantiate=not issubclass(lightning_class, Trainer),
            sub_configs=True,
        )

    def add_optimizer_args(
        self,
        optimizer_class: Union[Type[Optimizer], Tuple[Type[Optimizer], ...]] = (Optimizer,),
        nested_key: str = "optimizer",
        link_to: str = "AUTOMATIC",
    ) -> None:
        """Register an optimizer argument group under ``nested_key``.

        Args:
            optimizer_class: Any subclass of :class:`torch.optim.Optimizer`. Use tuple to allow subclasses.
            nested_key: Name of the nested namespace to store arguments.
            link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
        """
        # the optimizer is built later from the model parameters, so `params` is skipped and nothing is instantiated
        shared: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"params"}}
        if isinstance(optimizer_class, tuple):
            assert all(issubclass(o, Optimizer) for o in optimizer_class)
            self.add_subclass_arguments(optimizer_class, nested_key, **shared)
        else:
            assert issubclass(optimizer_class, Optimizer)
            self.add_class_arguments(optimizer_class, nested_key, sub_configs=True, **shared)
        self._optimizers[nested_key] = (optimizer_class, link_to)

    def add_lr_scheduler_args(
        self,
        lr_scheduler_class: Union[LRSchedulerType, Tuple[LRSchedulerType, ...]] = LRSchedulerTypeTuple,
        nested_key: str = "lr_scheduler",
        link_to: str = "AUTOMATIC",
    ) -> None:
        """Register a learning rate scheduler argument group under ``nested_key``.

        Args:
            lr_scheduler_class: Any subclass of ``torch.optim.lr_scheduler.{_LRScheduler, ReduceLROnPlateau}``. Use
                tuple to allow subclasses.
            nested_key: Name of the nested namespace to store arguments.
            link_to: Dot notation of a parser key to set arguments or AUTOMATIC.
        """
        # the scheduler is built later around the optimizer, so `optimizer` is skipped and nothing is instantiated
        shared: Dict[str, Any] = {"instantiate": False, "fail_untyped": False, "skip": {"optimizer"}}
        if isinstance(lr_scheduler_class, tuple):
            assert all(issubclass(o, LRSchedulerTypeTuple) for o in lr_scheduler_class)
            self.add_subclass_arguments(lr_scheduler_class, nested_key, **shared)
        else:
            assert issubclass(lr_scheduler_class, LRSchedulerTypeTuple)
            self.add_class_arguments(lr_scheduler_class, nested_key, sub_configs=True, **shared)
        self._lr_schedulers[nested_key] = (lr_scheduler_class, link_to)
class SaveConfigCallback(Callback):
    """Saves a LightningCLI config to the log_dir when training starts.

    Args:
        parser: The parser object used to parse the configuration.
        config: The parsed configuration that will be saved.
        config_filename: Filename for the config file.
        overwrite: Whether to overwrite an existing config file.
        multifile: When input is multiple config files, saved config preserves this structure.

    Raises:
        RuntimeError: If the config file already exists in the directory to avoid overwriting a previous run
    """

    def __init__(
        self,
        parser: LightningArgumentParser,
        config: Namespace,
        config_filename: str = "config.yaml",
        overwrite: bool = False,
        multifile: bool = False,
    ) -> None:
        self.parser = parser
        self.config = config
        self.config_filename = config_filename
        self.overwrite = overwrite
        self.multifile = multifile
        # set once the config has been written, so later `setup` calls return immediately
        self.already_saved = False

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        """Save ``self.config`` as ``<trainer.log_dir>/<config_filename>`` the first time this hook runs.

        Args:
            trainer: Supplies ``log_dir``, the global-rank-zero flag and the strategy used to broadcast.
            pl_module: Unused.
            stage: Unused.

        Raises:
            RuntimeError: If ``overwrite`` is ``False`` and the file already exists. The existence check runs on
                global rank zero and its result is broadcast, so every rank raises together.
        """
        if self.already_saved:
            return

        log_dir = trainer.log_dir  # this broadcasts the directory
        assert log_dir is not None
        config_path = os.path.join(log_dir, self.config_filename)
        fs = get_filesystem(log_dir)

        if not self.overwrite:
            # check if the file exists on rank 0
            file_exists = fs.isfile(config_path) if trainer.is_global_zero else False
            # broadcast whether to fail to all ranks
            file_exists = trainer.strategy.broadcast(file_exists)
            if file_exists:
                # point users at `save_config_kwargs`; the `save_config_overwrite` init parameter is deprecated
                raise RuntimeError(
                    f"{self.__class__.__name__} expected {config_path} to NOT exist. Aborting to avoid overwriting"
                    " results of a previous run. You can delete the previous config file,"
                    " set `LightningCLI(save_config_callback=None)` to disable config saving,"
                    ' or set `LightningCLI(save_config_kwargs={"overwrite": True})` to overwrite the config file.'
                )

        # save the file on rank 0
        if trainer.is_global_zero:
            # save only on rank zero to avoid race conditions.
            # the `log_dir` needs to be created as we rely on the logger to do it usually
            # but it hasn't logged anything at this point
            fs.makedirs(log_dir, exist_ok=True)
            self.parser.save(
                self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
            )
            self.already_saved = True

        # broadcast so that all ranks are in sync on future calls to .setup()
        self.already_saved = trainer.strategy.broadcast(self.already_saved)
class LightningCLI:
    """Implementation of a configurable command line tool for pytorch-lightning."""

    def __init__(
        self,
        model_class: Optional[Union[Type[LightningModule], Callable[..., LightningModule]]] = None,
        datamodule_class: Optional[Union[Type[LightningDataModule], Callable[..., LightningDataModule]]] = None,
        save_config_callback: Optional[Type[SaveConfigCallback]] = SaveConfigCallback,
        save_config_kwargs: Optional[Dict[str, Any]] = None,
        trainer_class: Union[Type[Trainer], Callable[..., Trainer]] = Trainer,
        trainer_defaults: Optional[Dict[str, Any]] = None,
        seed_everything_default: Union[bool, int] = True,
        description: str = "pytorch-lightning trainer command line tool",
        env_prefix: str = "PL",
        env_parse: bool = False,
        parser_kwargs: Optional[Union[Dict[str, Any], Dict[str, Dict[str, Any]]]] = None,
        subclass_mode_model: bool = False,
        subclass_mode_data: bool = False,
        args: ArgsType = None,
        run: bool = True,
        auto_registry: bool = False,
        **kwargs: Any,  # Remove with deprecations of v1.10
    ) -> None:
        """Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which
        are called / instantiated using a parsed configuration file and / or command line args.

        Parsing of configuration from environment variables can be enabled by setting ``env_parse=True``.
        A full configuration yaml would be parsed from ``PL_CONFIG`` if set. Individual settings are likewise parsed
        from variables named for example ``PL_TRAINER__MAX_EPOCHS``.

        For more info, read :ref:`the CLI docs <lightning-cli>`.

        .. warning:: ``LightningCLI`` is in beta and subject to change.

        Args:
            model_class: An optional :class:`~pytorch_lightning.core.module.LightningModule` class to train on or a
                callable which returns a :class:`~pytorch_lightning.core.module.LightningModule` instance when
                called. If ``None``, you can pass a registered model with ``--model=MyModel``.
            datamodule_class: An optional :class:`~pytorch_lightning.core.datamodule.LightningDataModule` class or a
                callable which returns a :class:`~pytorch_lightning.core.datamodule.LightningDataModule` instance when
                called. If ``None``, you can pass a registered datamodule with ``--data=MyDataModule``.
            save_config_callback: A callback class to save the config.
            save_config_kwargs: Parameters that will be used to instantiate the save_config_callback. The dict is
                copied, so the caller's object is never modified.
            trainer_class: An optional subclass of the :class:`~pytorch_lightning.trainer.trainer.Trainer` class or a
                callable which returns a :class:`~pytorch_lightning.trainer.trainer.Trainer` instance when called.
            trainer_defaults: Set to override Trainer defaults or add persistent callbacks. The callbacks added
                through this argument will not be configurable from a configuration file and will always be present
                for this particular CLI. Alternatively, configurable callbacks can be added as explained in
                :ref:`the CLI docs <lightning-cli>`.
            seed_everything_default: Number for the :func:`~lightning_lite.utilities.seed.seed_everything`
                seed value. Set to True to automatically choose a seed value.
                Setting it to False will avoid calling ``seed_everything``.
            description: Description of the tool shown when running ``--help``.
            env_prefix: Prefix for environment variables.
            env_parse: Whether environment variable parsing is enabled.
            parser_kwargs: Additional arguments to instantiate each ``LightningArgumentParser``. Keys named after a
                subcommand (see :meth:`subcommands`) hold the arguments for that subcommand's parser; all other
                keys are applied to the main parser.
            subclass_mode_model: Whether model can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            subclass_mode_data: Whether datamodule can be any `subclass
                <https://jsonargparse.readthedocs.io/en/stable/#class-type-and-sub-classes>`_
                of the given class.
            args: Arguments to parse. If ``None`` the arguments are taken from ``sys.argv``. Command line style
                arguments can be given in a ``list``. Alternatively, structured config options can be given in a
                ``dict`` or ``jsonargparse.Namespace``.
            run: Whether subcommands should be added to run a :class:`~pytorch_lightning.trainer.trainer.Trainer`
                method. If set to ``False``, the trainer and model classes will be instantiated only.
            auto_registry: Whether to automatically fill up the registries with all defined subclasses.
        """
        self.save_config_callback = save_config_callback
        # copy: `_handle_deprecated_params` may add keys and must not mutate the caller's dict
        self.save_config_kwargs = dict(save_config_kwargs or {})
        self.trainer_class = trainer_class
        self.trainer_defaults = trainer_defaults or {}
        self.seed_everything_default = seed_everything_default

        self._handle_deprecated_params(kwargs)

        self.model_class = model_class
        # used to differentiate between the original value and the processed value
        self._model_class = model_class or LightningModule
        self.subclass_mode_model = (model_class is None) or subclass_mode_model

        self.datamodule_class = datamodule_class
        # used to differentiate between the original value and the processed value
        self._datamodule_class = datamodule_class or LightningDataModule
        self.subclass_mode_data = (datamodule_class is None) or subclass_mode_data

        from pytorch_lightning.utilities.cli import _populate_registries

        _populate_registries(auto_registry)

        main_kwargs, subparser_kwargs = self._setup_parser_kwargs(
            parser_kwargs or {},  # type: ignore  # github.com/python/mypy/issues/6463
            {"description": description, "env_prefix": env_prefix, "default_env": env_parse},
        )
        self.setup_parser(run, main_kwargs, subparser_kwargs)
        self.parse_arguments(self.parser, args)

        self.subcommand = self.config["subcommand"] if run else None

        self._set_seed()

        self.before_instantiate_classes()
        self.instantiate_classes()

        if self.subcommand is not None:
            self._run_subcommand(self.subcommand)

    def _handle_deprecated_params(self, kwargs: dict) -> None:
        """Translate deprecated init parameters into their replacements and reject unknown keyword arguments.

        Raises:
            ValueError: If ``kwargs`` still holds keys after the deprecated ones are consumed.
        """
        if self.seed_everything_default is None:
            rank_zero_deprecation(
                "Setting `LightningCLI.seed_everything_default` to `None` is deprecated in v1.7 "
                "and will be removed in v1.9. Set it to `False` instead."
            )
            self.seed_everything_default = False

        for name in ["save_config_filename", "save_config_overwrite", "save_config_multifile"]:
            if name in kwargs:
                value = kwargs.pop(name)
                # e.g. "save_config_filename" -> "config_filename", "save_config_overwrite" -> "overwrite"
                key = name.replace("save_config_", "").replace("filename", "config_filename")
                self.save_config_kwargs[key] = value
                rank_zero_deprecation(
                    f"LightningCLI's {name!r} init parameter is deprecated from v1.8 "
                    "and will be removed in v1.10. Use 'save_config_kwargs' instead."
                )

        if kwargs:
            raise ValueError(f"Unexpected keyword parameters: {kwargs}")

    def _setup_parser_kwargs(
        self, kwargs: Dict[str, Any], defaults: Dict[str, Any]
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Split ``parser_kwargs`` into main-parser kwargs and per-subcommand kwargs.

        Keys matching a subcommand name go to the subcommand parsers. Every other key is merged over ``defaults``
        for the main parser (previously such keys were silently dropped as soon as any subcommand key was present).
        """
        subcommand_names = self.subcommands().keys()
        main_kwargs = {**defaults, **{k: v for k, v in kwargs.items() if k not in subcommand_names}}
        subparser_kwargs = {k: v for k, v in kwargs.items() if k in subcommand_names}
        return main_kwargs, subparser_kwargs

    def init_parser(self, **kwargs: Any) -> LightningArgumentParser:
        """Method that instantiates the argument parser."""
        kwargs.setdefault("dump_header", [f"pytorch_lightning=={pl.__version__}"])
        parser = LightningArgumentParser(**kwargs)
        parser.add_argument(
            "-c", "--config", action=ActionConfigFile, help="Path to a configuration file in json or yaml format."
        )
        return parser

    def setup_parser(
        self, add_subcommands: bool, main_kwargs: Dict[str, Any], subparser_kwargs: Dict[str, Any]
    ) -> None:
        """Initialize and setup the parser, subcommands, and arguments."""
        self.parser = self.init_parser(**main_kwargs)
        if add_subcommands:
            self._subcommand_method_arguments: Dict[str, List[str]] = {}
            self._add_subcommands(self.parser, **subparser_kwargs)
        else:
            self._add_arguments(self.parser)

    def add_default_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Adds default arguments to the parser."""
        parser.add_argument(
            "--seed_everything",
            type=Union[bool, int],
            default=self.seed_everything_default,
            help=(
                "Set to an int to run seed_everything with this value before classes instantiation. "
                "Set to True to use a random seed."
            ),
        )

    def add_core_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Adds arguments from the core classes to the parser."""
        parser.add_lightning_class_args(self.trainer_class, "trainer")
        # callbacks from `trainer_defaults` are appended at trainer instantiation, not exposed as parser defaults
        trainer_defaults = {"trainer." + k: v for k, v in self.trainer_defaults.items() if k != "callbacks"}
        parser.set_defaults(trainer_defaults)
        parser.add_lightning_class_args(self._model_class, "model", subclass_mode=self.subclass_mode_model)
        if self.datamodule_class is not None:
            parser.add_lightning_class_args(self._datamodule_class, "data", subclass_mode=self.subclass_mode_data)
        else:
            # this should not be required because the user might want to use the `LightningModule` dataloaders
            parser.add_lightning_class_args(
                self._datamodule_class, "data", subclass_mode=self.subclass_mode_data, required=False
            )

    def _add_arguments(self, parser: LightningArgumentParser) -> None:
        # default + core + custom arguments
        self.add_default_arguments_to_parser(parser)
        self.add_core_arguments_to_parser(parser)
        self.add_arguments_to_parser(parser)
        # add default optimizer args if necessary
        if not parser._optimizers:  # already added by the user in `add_arguments_to_parser`
            parser.add_optimizer_args((Optimizer,))
        if not parser._lr_schedulers:  # already added by the user in `add_arguments_to_parser`
            parser.add_lr_scheduler_args(LRSchedulerTypeTuple)
        self.link_optimizers_and_lr_schedulers(parser)

    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
        """Implement to add extra arguments to the parser or link arguments.

        Args:
            parser: The parser object to which arguments can be added
        """

    @staticmethod
    def subcommands() -> Dict[str, Set[str]]:
        """Defines the list of available subcommands and the arguments to skip."""
        return {
            "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
            "validate": {"model", "dataloaders", "datamodule"},
            "test": {"model", "dataloaders", "datamodule"},
            "predict": {"model", "dataloaders", "datamodule"},
            "tune": {"model", "train_dataloaders", "val_dataloaders", "datamodule"},
        }

    def _add_subcommands(self, parser: LightningArgumentParser, **kwargs: Any) -> None:
        """Adds subcommands to the input parser."""
        self._subcommand_parsers: Dict[str, LightningArgumentParser] = {}
        parser_subcommands = parser.add_subcommands()
        # the user might have passed a builder function
        trainer_class = (
            self.trainer_class if isinstance(self.trainer_class, type) else class_from_function(self.trainer_class)
        )
        # register all subcommands in separate subcommand parsers under the main parser
        for subcommand in self.subcommands():
            fn = getattr(trainer_class, subcommand)
            # extract the first line description in the docstring for the subcommand help message
            description = _get_short_description(fn)
            # build a fresh dict: the user's per-subcommand kwargs must not be mutated; a user description wins
            subparser_kwargs = {"description": description, **kwargs.get(subcommand, {})}
            subcommand_parser = self._prepare_subcommand_parser(trainer_class, subcommand, **subparser_kwargs)
            self._subcommand_parsers[subcommand] = subcommand_parser
            parser_subcommands.add_subcommand(subcommand, subcommand_parser, help=description)

    def _prepare_subcommand_parser(self, klass: Type, subcommand: str, **kwargs: Any) -> LightningArgumentParser:
        parser = self.init_parser(**kwargs)
        self._add_arguments(parser)
        # subcommand arguments
        skip = self.subcommands()[subcommand]
        added = parser.add_method_arguments(klass, subcommand, skip=skip)
        # need to save which arguments were added to pass them to the method later
        self._subcommand_method_arguments[subcommand] = added
        return parser

    @staticmethod
    def link_optimizers_and_lr_schedulers(parser: LightningArgumentParser) -> None:
        """Creates argument links for optimizers and learning rate schedulers that specified a ``link_to``.

        Groups registered with ``link_to="AUTOMATIC"`` are skipped here; those are handled by
        ``_add_configure_optimizers_method_to_model``.
        """
        optimizers_and_lr_schedulers = {**parser._optimizers, **parser._lr_schedulers}
        for key, (class_type, link_to) in optimizers_and_lr_schedulers.items():
            if link_to == "AUTOMATIC":
                continue
            if isinstance(class_type, tuple):
                # subclass mode: the group value is already in ``class_path`` / ``init_args`` form
                parser.link_arguments(key, link_to)
            else:
                # single class: the group only holds init args, so wrap them with the class path
                add_class_path = _add_class_path_generator(class_type)
                parser.link_arguments(key, link_to, compute_fn=add_class_path)

    def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> None:
        """Parses command line arguments and stores it in ``self.config``."""
        if args is not None and len(sys.argv) > 1:
            raise ValueError(
                "LightningCLI's args parameter is intended to run from within Python like if it were from the command "
                "line. To prevent mistakes it is not allowed to provide both args and command line arguments, got: "
                f"sys.argv[1:]={sys.argv[1:]}, args={args}."
            )
        if isinstance(args, (dict, Namespace)):
            self.config = parser.parse_object(args)
        else:
            self.config = parser.parse_args(args)

    def before_instantiate_classes(self) -> None:
        """Implement to run some code before instantiating the classes."""

    def instantiate_classes(self) -> None:
        """Instantiates the classes and sets their attributes."""
        self.config_init = self.parser.instantiate_classes(self.config)
        self.datamodule = self._get(self.config_init, "data")
        self.model = self._get(self.config_init, "model")
        self._add_configure_optimizers_method_to_model(self.subcommand)
        self.trainer = self.instantiate_trainer()

    def instantiate_trainer(self, **kwargs: Any) -> Trainer:
        """Instantiates the trainer.

        Args:
            kwargs: Any custom trainer arguments.
        """
        extra_callbacks = [self._get(self.config_init, c) for c in self._parser(self.subcommand).callback_keys]
        trainer_config = {**self._get(self.config_init, "trainer", default={}), **kwargs}
        return self._instantiate_trainer(trainer_config, extra_callbacks)

    def _instantiate_trainer(self, config: Dict[str, Any], callbacks: List[Callback]) -> Trainer:
        key = "callbacks"
        if key in config:
            # normalize to a list before appending the parser-configured and default callbacks
            if config[key] is None:
                config[key] = []
            elif not isinstance(config[key], list):
                config[key] = [config[key]]
            config[key].extend(callbacks)
            if key in self.trainer_defaults:
                value = self.trainer_defaults[key]
                config[key] += value if isinstance(value, list) else [value]
            if self.save_config_callback and not config.get("fast_dev_run", False):
                config_callback = self.save_config_callback(
                    self._parser(self.subcommand),
                    self.config.get(str(self.subcommand), self.config),
                    **self.save_config_kwargs,
                )
                config[key].append(config_callback)
        else:
            rank_zero_warn(
                f"The `{self.trainer_class.__qualname__}` class does not expose the `{key}` argument so they will"
                " not be included."
            )
        return self.trainer_class(**config)

    def _parser(self, subcommand: Optional[str]) -> LightningArgumentParser:
        if subcommand is None:
            return self.parser
        # return the subcommand parser for the subcommand passed
        return self._subcommand_parsers[subcommand]

    @staticmethod
    def configure_optimizers(
        lightning_module: LightningModule, optimizer: Optimizer, lr_scheduler: Optional[LRSchedulerTypeUnion] = None
    ) -> Any:
        """Override to customize the :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
        method.

        Args:
            lightning_module: A reference to the model.
            optimizer: The optimizer.
            lr_scheduler: The learning rate scheduler (if used).
        """
        if lr_scheduler is None:
            return optimizer
        if isinstance(lr_scheduler, ReduceLROnPlateau):
            return {
                "optimizer": optimizer,
                "lr_scheduler": {"scheduler": lr_scheduler, "monitor": lr_scheduler.monitor},
            }
        return [optimizer], [lr_scheduler]

    def _add_configure_optimizers_method_to_model(self, subcommand: Optional[str]) -> None:
        """Overrides the model's :meth:`~pytorch_lightning.core.module.LightningModule.configure_optimizers`
        method if a single optimizer and optionally a scheduler argument groups are added to the parser as
        'AUTOMATIC'."""
        parser = self._parser(subcommand)

        def get_automatic(
            class_type: Union[Type, Tuple[Type, ...]], register: Dict[str, Tuple[Union[Type, Tuple[Type, ...]], str]]
        ) -> List[str]:
            automatic = []
            for key, (base_class, link_to) in register.items():
                if not isinstance(base_class, tuple):
                    base_class = (base_class,)
                if link_to == "AUTOMATIC" and any(issubclass(c, class_type) for c in base_class):
                    automatic.append(key)
            return automatic

        optimizers = get_automatic(Optimizer, parser._optimizers)
        lr_schedulers = get_automatic(LRSchedulerTypeTuple, parser._lr_schedulers)

        if len(optimizers) == 0:
            return

        if len(optimizers) > 1 or len(lr_schedulers) > 1:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}.add_configure_optimizers_method_to_model` expects at most one optimizer "
                f"and one lr_scheduler to be 'AUTOMATIC', but found {optimizers+lr_schedulers}. In this case the user "
                "is expected to link the argument groups and implement `configure_optimizers`, see "
                "https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_cli.html"
                "#optimizers-and-learning-rate-schedulers"
            )

        optimizer_class = parser._optimizers[optimizers[0]][0]
        optimizer_init = self._get(self.config_init, optimizers[0])
        if not isinstance(optimizer_class, tuple):
            optimizer_init = _global_add_class_path(optimizer_class, optimizer_init)
        if not optimizer_init:
            # optimizers were registered automatically but not passed by the user
            return

        lr_scheduler_init = None
        if lr_schedulers:
            lr_scheduler_class = parser._lr_schedulers[lr_schedulers[0]][0]
            lr_scheduler_init = self._get(self.config_init, lr_schedulers[0])
            if not isinstance(lr_scheduler_class, tuple):
                lr_scheduler_init = _global_add_class_path(lr_scheduler_class, lr_scheduler_init)

        if is_overridden("configure_optimizers", self.model):
            _warn(
                f"`{self.model.__class__.__name__}.configure_optimizers` will be overridden by "
                f"`{self.__class__.__name__}.configure_optimizers`."
            )

        optimizer = instantiate_class(self.model.parameters(), optimizer_init)
        lr_scheduler = instantiate_class(optimizer, lr_scheduler_init) if lr_scheduler_init else None
        fn = partial(self.configure_optimizers, optimizer=optimizer, lr_scheduler=lr_scheduler)
        update_wrapper(fn, self.configure_optimizers)  # necessary for `is_overridden`
        # override the existing method
        self.model.configure_optimizers = MethodType(fn, self.model)

    def _get(self, config: Namespace, key: str, default: Optional[Any] = None) -> Any:
        """Utility to get a config value which might be inside a subcommand."""
        return config.get(str(self.subcommand), config).get(key, default)

    def _run_subcommand(self, subcommand: str) -> None:
        """Run the chosen subcommand."""
        before_fn = getattr(self, f"before_{subcommand}", None)
        if callable(before_fn):
            before_fn()

        # a method on the CLI with the subcommand's name takes precedence over the trainer's
        default = getattr(self.trainer, subcommand)
        fn = getattr(self, subcommand, default)
        fn_kwargs = self._prepare_subcommand_kwargs(subcommand)
        fn(**fn_kwargs)

        after_fn = getattr(self, f"after_{subcommand}", None)
        if callable(after_fn):
            after_fn()

    def _prepare_subcommand_kwargs(self, subcommand: str) -> Dict[str, Any]:
        """Prepares the keyword arguments to pass to the subcommand to run."""
        fn_kwargs = {
            k: v for k, v in self.config_init[subcommand].items() if k in self._subcommand_method_arguments[subcommand]
        }
        fn_kwargs["model"] = self.model
        if self.datamodule is not None:
            fn_kwargs["datamodule"] = self.datamodule
        return fn_kwargs

    def _set_seed(self) -> None:
        """Sets the seed and records the value actually used in the config."""
        config_seed = self._get(self.config, "seed_everything")
        if config_seed is False:
            return
        if config_seed is True:
            # user requested seeding, choose randomly
            config_seed = seed_everything(workers=True)
        else:
            config_seed = seed_everything(config_seed, workers=True)
        # Write back into the same namespace `_get` read from. With subcommands that is the subcommand namespace,
        # which is what `SaveConfigCallback` receives, so the saved config records the seed that was used.
        if self.subcommand is not None:
            self.config[self.subcommand]["seed_everything"] = config_seed
        else:
            self.config["seed_everything"] = config_seed
def _class_path_from_class(class_type: Type) -> str:
    """Return the dotted import path ``<module>.<name>`` of ``class_type``."""
    return f"{class_type.__module__}.{class_type.__name__}"


def _global_add_class_path(
    class_type: Type, init_args: Optional[Union[Namespace, Dict[str, Any]]] = None
) -> Dict[str, Any]:
    """Build a ``{"class_path": ..., "init_args": ...}`` dict for ``class_type``.

    A ``Namespace`` is converted to a plain dict first; a falsy value (``None`` or empty) becomes ``{}``.
    """
    plain_args = init_args.as_dict() if isinstance(init_args, Namespace) else init_args
    return {"class_path": _class_path_from_class(class_type), "init_args": plain_args or {}}


def _add_class_path_generator(class_type: Type) -> Callable[[Namespace], Dict[str, Any]]:
    """Return a one-argument callable that wraps init args together with ``class_type``'s class path."""

    def add_class_path(init_args: Namespace) -> Dict[str, Any]:
        return _global_add_class_path(class_type, init_args)

    return add_class_path


def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiates a class with the given args and init.

    Args:
        args: Positional arguments required for instantiation. A non-tuple value is passed as the single
            positional argument.
        init: Dict of the form {"class_path":...,"init_args":...}. ``init_args`` may be omitted.

    Returns:
        The instantiated class object.
    """
    init_kwargs = init.get("init_args", {})
    positional = args if isinstance(args, tuple) else (args,)
    module_name, class_name = init["class_path"].rsplit(".", 1)
    # a non-empty `fromlist` makes `__import__` return the leaf module instead of the top-level package
    target_cls = getattr(__import__(module_name, fromlist=[class_name]), class_name)
    return target_cls(*positional, **init_kwargs)


def _get_short_description(component: object) -> Optional[str]:
    """Return the one-line summary of ``component``'s docstring.

    Returns ``None`` when there is no docstring, or when parsing it fails (a warning is emitted in that case).
    """
    doc = component.__doc__
    if doc is None:
        return None
    try:
        parsed = docstring_parser.parse(doc)
    except (ValueError, docstring_parser.ParseError) as ex:
        rank_zero_warn(f"Failed parsing docstring for {component}: {ex}")
        return None
    return parsed.short_description

© Copyright (c) 2018-2023, Lightning AI et al.

Built with Sphinx using a theme provided by Read the Docs.