Source code for lightning.pytorch.core.mixins.hparams_mixin
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import inspect
import types
from argparse import Namespace
from typing import Any, List, MutableMapping, Optional, Sequence, Union
from lightning.fabric.utilities.data import AttributeDict
from lightning.pytorch.utilities.parsing import save_hyperparameters
_PRIMITIVE_TYPES = (bool, int, float, str)
_ALLOWED_CONFIG_TYPES = (AttributeDict, MutableMapping, Namespace)
[docs]class HyperparametersMixin:
__jit_unused_properties__: List[str] = ["hparams", "hparams_initial"]
def __init__(self) -> None:
super().__init__()
self._log_hyperparams = False
[docs] def save_hyperparameters(
self,
*args: Any,
ignore: Optional[Union[Sequence[str], str]] = None,
frame: Optional[types.FrameType] = None,
logger: bool = True,
) -> None:
"""Save arguments to ``hparams`` attribute.
Args:
args: single object of `dict`, `NameSpace` or `OmegaConf`
or string names or arguments from class ``__init__``
ignore: an argument name or a list of argument names from
class ``__init__`` to be ignored
frame: a frame object. Default is None
logger: Whether to send the hyperparameters to the logger. Default: True
Example::
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # manually assign arguments
... self.save_hyperparameters('arg1', 'arg3')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class AutomaticArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # equivalent automatic
... self.save_hyperparameters()
... def forward(self, *args, **kwargs):
... ...
>>> model = AutomaticArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg2": abc
"arg3": 3.14
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class SingleArgModel(HyperparametersMixin):
... def __init__(self, params):
... super().__init__()
... # manually assign single argument
... self.save_hyperparameters(params)
... def forward(self, *args, **kwargs):
... ...
>>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14))
>>> model.hparams
"p1": 1
"p2": abc
"p3": 3.14
>>> from lightning.pytorch.core.mixins import HyperparametersMixin
>>> class ManuallyArgsModel(HyperparametersMixin):
... def __init__(self, arg1, arg2, arg3):
... super().__init__()
... # pass argument(s) to ignore as a string or in a list
... self.save_hyperparameters(ignore='arg2')
... def forward(self, *args, **kwargs):
... ...
>>> model = ManuallyArgsModel(1, 'abc', 3.14)
>>> model.hparams
"arg1": 1
"arg3": 3.14
"""
self._log_hyperparams = logger
# the frame needs to be created in this file.
if not frame:
current_frame = inspect.currentframe()
if current_frame:
frame = current_frame.f_back
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
def _set_hparams(self, hp: Union[MutableMapping, Namespace, str]) -> None:
hp = self._to_hparams_dict(hp)
if isinstance(hp, dict) and isinstance(self.hparams, dict):
self.hparams.update(hp)
else:
self._hparams = hp
@staticmethod
def _to_hparams_dict(hp: Union[MutableMapping, Namespace, str]) -> Union[MutableMapping, AttributeDict]:
if isinstance(hp, Namespace):
hp = vars(hp)
if isinstance(hp, dict):
hp = AttributeDict(hp)
elif isinstance(hp, _PRIMITIVE_TYPES):
raise ValueError(f"Primitives {_PRIMITIVE_TYPES} are not allowed.")
elif not isinstance(hp, _ALLOWED_CONFIG_TYPES):
raise ValueError(f"Unsupported config type of {type(hp)}.")
return hp
@property
def hparams(self) -> Union[AttributeDict, MutableMapping]:
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. It is mutable by the user. For
the frozen set of initial hyperparameters, use :attr:`hparams_initial`.
Returns:
Mutable hyperparameters dictionary
"""
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()
return self._hparams
@property
def hparams_initial(self) -> AttributeDict:
"""The collection of hyperparameters saved with :meth:`save_hyperparameters`. These contents are read-only.
Manual updates to the saved hyperparameters can instead be performed through :attr:`hparams`.
Returns:
AttributeDict: immutable initial hyperparameters
"""
if not hasattr(self, "_hparams_initial"):
return AttributeDict()
# prevent any change
return copy.deepcopy(self._hparams_initial)