# Source code for pytorch_lightning.utilities.cli
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Deprecated utilities for LightningCLI."""
import inspect
from types import ModuleType
from typing import Any, Generator, List, Optional, Tuple, Type
import torch
from lightning_utilities.core.inheritance import get_all_subclasses
from torch.optim import Optimizer
import pytorch_lightning as pl
import pytorch_lightning.cli as new_cli
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
# Deprecation text used by `_Registry` the first time one of the legacy registries is used.
_deprecate_registry_message = (
    "`LightningCLI`'s registries were deprecated in v1.7 and will be removed in v1.9."
    " Now any imported subclass is automatically available by name in `LightningCLI`"
    " without any need to explicitly register it."
)
# Deprecation text used by `_populate_registries` when asked to auto-register all loaded subclasses.
_deprecate_auto_registry_message = (
    "`LightningCLI.auto_registry` parameter was deprecated in v1.7 and will be removed in v1.9."
    " Now any imported subclass is automatically available by name in `LightningCLI`"
    " without any need to explicitly register it."
)
class _Registry(dict):  # Remove in v1.9
    """Legacy name -> class mapping kept for backward compatibility with the old ``LightningCLI`` registries.

    The registry deprecation message is reported at most once per instance: the first time a
    registration is made with ``show_deprecation=True``, or ``names`` / ``classes`` / ``str()`` is used.
    """

    def __call__(
        self, cls: Type, key: Optional[str] = None, override: bool = False, show_deprecation: bool = True
    ) -> Type:
        """Map a name to ``cls`` and hand ``cls`` back, so the registry can also be used as a class decorator.

        Args:
            cls: the class to be mapped.
            key: the name that identifies the provided class; ``cls.__name__`` when omitted.
            override: whether an already-registered ``key`` should be replaced by ``cls``.
            show_deprecation: whether this call may report the registry deprecation.

        Returns:
            ``cls`` unchanged.

        Raises:
            TypeError: if ``key`` is provided and is not a ``str``.
        """
        if key is not None and not isinstance(key, str):
            raise TypeError(f"`key` must be a str, found {key}")
        name = cls.__name__ if key is None else key
        already_taken = name in self
        # Existing entries are kept unless the caller explicitly asks to replace them.
        if not already_taken or override:
            self[name] = cls
        self._deprecation(show_deprecation)
        return cls

    def register_classes(
        self, module: ModuleType, base_cls: Type, override: bool = False, show_deprecation: bool = True
    ) -> None:
        """Register, under their class names, all strict subclasses of ``base_cls`` found on ``module``."""
        for member in self.get_members(module, base_cls):
            self(cls=member, override=override, show_deprecation=show_deprecation)

    @staticmethod
    def get_members(module: ModuleType, base_cls: Type) -> Generator[Type, None, None]:
        """Return a generator over the class attributes of ``module`` that subclass ``base_cls``.

        ``base_cls`` itself is excluded. Members come in ``inspect.getmembers`` order (sorted by name).
        """
        class_members = inspect.getmembers(module, predicate=inspect.isclass)
        return (member for _, member in class_members if issubclass(member, base_cls) and member != base_cls)

    @property
    def names(self) -> List[str]:
        """The registered names, in insertion order."""
        self._deprecation()
        return [*self]

    @property
    def classes(self) -> Tuple[Type, ...]:
        """The registered classes, in insertion order."""
        self._deprecation()
        return tuple(self[name] for name in self)

    def __str__(self) -> str:
        registered = self.names
        return f"Registered objects: {registered}"

    def _deprecation(self, show_deprecation: bool = True) -> None:
        # The "already warned" flag lives on the instance, so each registry warns at most once.
        if not show_deprecation or getattr(self, "deprecation_shown", False):
            return
        rank_zero_deprecation(_deprecate_registry_message)
        self.deprecation_shown = True
# Deprecated module-level registries (remove in v1.9); each name is bound to its own, independent `_Registry`.
(
    OPTIMIZER_REGISTRY,
    LR_SCHEDULER_REGISTRY,
    CALLBACK_REGISTRY,
    MODEL_REGISTRY,
    DATAMODULE_REGISTRY,
    LOGGER_REGISTRY,
) = (_Registry() for _ in range(6))
def _populate_registries(subclasses: bool) -> None:  # Remove in v1.9
    """Fill the deprecated module-level registries.

    Args:
        subclasses: if ``True``, report the ``auto_registry`` deprecation and register every currently
            loaded subclass (as returned by ``get_all_subclasses``) of the optimizer, ``_LRScheduler``,
            callback, ``LightningModule``, ``LightningDataModule`` and logger base classes. If ``False``,
            register only the subclasses exposed by ``torch.optim``, ``torch.optim.lr_scheduler``,
            ``pl.callbacks`` and ``pl.loggers``.

    In both modes Lightning's ``ReduceLROnPlateau`` is added to the LR scheduler registry afterwards.
    All registrations pass ``show_deprecation=False``.
    """
    if subclasses:
        rank_zero_deprecation(_deprecate_auto_registry_message)
        # this will register any subclasses from all loaded modules including userland
        auto_targets = (
            (OPTIMIZER_REGISTRY, torch.optim.Optimizer),
            (LR_SCHEDULER_REGISTRY, torch.optim.lr_scheduler._LRScheduler),
            (CALLBACK_REGISTRY, pl.Callback),
            (MODEL_REGISTRY, pl.LightningModule),
            (DATAMODULE_REGISTRY, pl.LightningDataModule),
            (LOGGER_REGISTRY, pl.loggers.Logger),
        )
        for registry, base in auto_targets:
            for subclass in get_all_subclasses(base):
                registry(subclass, show_deprecation=False)
    else:
        # manually register torch's subclasses and our subclasses
        module_targets = (
            (OPTIMIZER_REGISTRY, torch.optim, Optimizer),
            (LR_SCHEDULER_REGISTRY, torch.optim.lr_scheduler, torch.optim.lr_scheduler._LRScheduler),
            (CALLBACK_REGISTRY, pl.callbacks, pl.Callback),
            (LOGGER_REGISTRY, pl.loggers, pl.loggers.Logger),
        )
        for registry, module, base in module_targets:
            registry.register_classes(module, base, show_deprecation=False)
    # `ReduceLROnPlateau` does not subclass `_LRScheduler`, so it is registered explicitly
    LR_SCHEDULER_REGISTRY(cls=new_cli.ReduceLROnPlateau, show_deprecation=False)
def _deprecation(cls: Type) -> None:
    """Report via ``rank_zero_deprecation`` that ``cls`` should be imported from ``pytorch_lightning.cli``."""
    class_name = cls.__name__
    old_path = f"pytorch_lightning.utilities.cli.{class_name}"
    new_path = f"pytorch_lightning.cli.{class_name}"
    rank_zero_deprecation(
        f"`{old_path}` has been deprecated in v1.7 and will be removed in v1.9."
        f" Use the equivalent class in `{new_path}` instead."
    )
class LightningArgumentParser(new_cli.LightningArgumentParser):
    """Deprecated subclass of ``pytorch_lightning.cli.LightningArgumentParser``, kept for backward compatibility.

    Construction reports the deprecation and then defers entirely to the new implementation.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Name the deprecated class explicitly: `type(self)` would put a user subclass's name into
        # module paths (`pytorch_lightning.utilities.cli.<Name>` / `pytorch_lightning.cli.<Name>`) that don't exist.
        _deprecation(LightningArgumentParser)
        super().__init__(*args, **kwargs)
class SaveConfigCallback(new_cli.SaveConfigCallback):
    """Deprecated subclass of ``pytorch_lightning.cli.SaveConfigCallback``, kept for backward compatibility.

    Construction reports the deprecation and then defers entirely to the new implementation.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Name the deprecated class explicitly: `type(self)` would put a user subclass's name into
        # module paths (`pytorch_lightning.utilities.cli.<Name>` / `pytorch_lightning.cli.<Name>`) that don't exist.
        _deprecation(SaveConfigCallback)
        super().__init__(*args, **kwargs)
class LightningCLI(new_cli.LightningCLI):
    """Deprecated subclass of ``pytorch_lightning.cli.LightningCLI``, kept for backward compatibility.

    Construction reports the deprecation and then defers entirely to the new implementation.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Name the deprecated class explicitly: users commonly subclass `LightningCLI`, and `type(self)`
        # would put that subclass's name into module paths that don't exist.
        _deprecation(LightningCLI)
        super().__init__(*args, **kwargs)
def instantiate_class(*args: Any, **kwargs: Any) -> Any:
    """Deprecated forwarder to ``pytorch_lightning.cli.instantiate_class``.

    Reports the deprecation via ``rank_zero_deprecation``, then returns whatever the new function
    returns for the same positional and keyword arguments.
    """
    deprecation_notice = (
        "`pytorch_lightning.utilities.cli.instantiate_class` has been deprecated in v1.7 and will be removed in v1.9."
        " Use the equivalent function in `pytorch_lightning.cli.instantiate_class` instead."
    )
    rank_zero_deprecation(deprecation_notice)
    forwarded = new_cli.instantiate_class(*args, **kwargs)
    return forwarded