Shortcuts

Source code for lightning.pytorch.core.mixins.hparams_mixin

# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union

from lightning.pytorch.utilities.parsing import AttributeDict, save_hyperparameters

# Primitive value types that ``HyperparametersMixin._to_hparams_dict`` rejects with a ``ValueError``.
_PRIMITIVE_TYPES = (bool, int, float, str)
# Container types that ``HyperparametersMixin._to_hparams_dict`` accepts as a hyperparameter config.
_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)


class HyperparametersMixin:
    """Mixin that stores hyperparameters on the instance and exposes them as properties.

    The mutable collection is available through :attr:`hparams`; a protected copy of the
    initial values is available through :attr:`hparams_initial`.
    """

    # NOTE(review): presumably tells TorchScript to skip these properties when scripting -- confirm.
    __jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]

    def __init__(self) -> None:
        super().__init__()
        # Updated by ``save_hyperparameters(logger=...)``; off until that method is called.
        self._log_hyperparams = False

    def save_hyperparameters(
        self,
        *args: Any,
        ignore: Optional[Union[Sequence[str], str]] = None,
        frame: Optional[types.FrameType] = None,
        logger: bool = True,
    ) -> None:
        """Save arguments to ``hparams`` attribute.

        Args:
            args: single object of `dict`, `Namespace` or `OmegaConf`
                or string names or arguments from class ``__init__``
            ignore: an argument name or a list of argument names from
                class ``__init__`` to be ignored
            frame: a frame object. Default is None, in which case the caller's frame is used
            logger: Whether to send the hyperparameters to the logger. Default: True

        Example::
            >>> from lightning.pytorch.core.mixins import HyperparametersMixin
            >>> class ManuallyArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # manually assign arguments
            ...         self.save_hyperparameters('arg1', 'arg3')
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg3": 3.14

            >>> from lightning.pytorch.core.mixins import HyperparametersMixin
            >>> class AutomaticArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # equivalent automatic
            ...         self.save_hyperparameters()
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = AutomaticArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg2": abc
            "arg3": 3.14

            >>> from lightning.pytorch.core.mixins import HyperparametersMixin
            >>> class SingleArgModel(HyperparametersMixin):
            ...     def __init__(self, params):
            ...         super().__init__()
            ...         # manually assign single argument
            ...         self.save_hyperparameters(params)
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
            >>> model.hparams
            "p1": 1
            "p2": abc
            "p3": 3.14

            >>> from lightning.pytorch.core.mixins import HyperparametersMixin
            >>> class ManuallyArgsModel(HyperparametersMixin):
            ...     def __init__(self, arg1, arg2, arg3):
            ...         super().__init__()
            ...         # pass argument(s) to ignore as a string or in a list
            ...         self.save_hyperparameters(ignore='arg2')
            ...     def forward(self, *args, **kwargs):
            ...         ...
            >>> model = ManuallyArgsModel(1, 'abc', 3.14)
            >>> model.hparams
            "arg1": 1
            "arg3": 3.14

        """
        self._log_hyperparams = logger
        # the frame needs to be created in this file.
        if not frame:
            current_frame = inspect.currentframe()
            # ``inspect.currentframe()`` may return None on interpreters without frame support.
            if current_frame:
                # One level up is the caller (typically the subclass ``__init__``).
                frame = current_frame.f_back
        save_hyperparameters(self, *args, ignore=ignore, frame=frame)

    def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
        """Merge ``hp`` into the stored hyperparameters, or replace them.

        ``hp`` is first normalized with :meth:`_to_hparams_dict`. If both the normalized
        value and the current :attr:`hparams` are ``dict`` instances, the current collection
        is updated in place; otherwise ``_hparams`` is overwritten with the normalized value.

        Raises:
            ValueError: propagated from :meth:`_to_hparams_dict` for unsupported inputs.
        """
        hp = self._to_hparams_dict(hp)

        if isinstance(hp, dict) and isinstance(self.hparams, dict):
            self.hparams.update(hp)
        else:
            self._hparams = hp

    @staticmethod
    def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
        """Normalize a hyperparameter container.

        A ``Namespace`` is converted through ``vars``; any ``dict`` is wrapped in an
        ``AttributeDict``. Other instances of ``_ALLOWED_CONFIG_TYPES`` are returned unchanged.

        Raises:
            ValueError: if ``hp`` is one of ``_PRIMITIVE_TYPES`` or any other type not in
                ``_ALLOWED_CONFIG_TYPES``.
        """
        if isinstance(hp, Namespace):
            hp = vars(hp)
        if isinstance(hp, dict):
            hp = AttributeDict(hp)
        elif isinstance(hp, _PRIMITIVE_TYPES):
            raise ValueError(f"Primitives {_PRIMITIVE_TYPES} are not allowed.")
        elif not isinstance(hp, _ALLOWED_CONFIG_TYPES):
            raise ValueError(f"Unsupported config type of {type(hp)}.")
        return hp

    @property
    def hparams(self) -> Union[AttributeDict, MutableMapping]:
        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For
        the frozen set of initial hyperparameters, use :attr:`hparams_initial`.

        Returns:
            Mutable hyperparameters dictionary

        """
        # Created lazily so the property also works when ``_hparams`` was never assigned.
        if not hasattr(self, "_hparams"):
            self._hparams = AttributeDict()
        return self._hparams

    @property
    def hparams_initial(self) -> AttributeDict:
        """The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
        Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.

        Returns:
            AttributeDict: a deep copy of the initial hyperparameters, or an empty
            ``AttributeDict`` when none have been stored

        """
        # NOTE(review): ``_hparams_initial`` is not assigned in this class; presumably it is set by
        # the ``save_hyperparameters`` utility -- confirm.
        if not hasattr(self, "_hparams_initial"):
            return AttributeDict()
        # prevent any change
        return copy.deepcopy(self._hparams_initial)

© Copyright 2018-2023, Lightning AI et al.

Built with Sphinx using a theme provided by Read the Docs.