Source code for pytorch_lightning.loggers.logger
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Abstract base class used to build new loggers."""
import functools
import operator
from abc import ABC
from collections import defaultdict
from typing import Any, Callable, Dict, Mapping, Optional, Sequence
import numpy as np
from lightning_fabric.loggers import Logger as FabricLogger
from lightning_fabric.loggers.logger import _DummyExperiment as DummyExperiment # for backward compatibility
from lightning_fabric.loggers.logger import rank_zero_experiment # noqa: F401 # for backward compatibility
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
class Logger(FabricLogger, ABC):
    """Base class for experiment loggers."""

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Hook called after the model checkpoint callback saves a new checkpoint.

        The base implementation does nothing; subclasses may override it.

        Args:
            checkpoint_callback: the model checkpoint callback instance
        """

    @property
    def save_dir(self) -> Optional[str]:
        """Root directory where experiment logs get saved.

        Returns:
            ``None`` in this base implementation, i.e. the logger does not save data locally.
        """
        return None
class DummyLogger(Logger):
    """Dummy logger for internal use.

    It is useful if we want to disable user's logger for a feature, but still ensure that user code can run
    """

    def __init__(self) -> None:
        super().__init__()
        self._experiment = DummyExperiment()

    @property
    def experiment(self) -> DummyExperiment:
        """Return the experiment object associated with this logger."""
        return self._experiment

    @property
    def name(self) -> str:
        """Return the experiment name (always an empty string)."""
        return ""

    @property
    def version(self) -> str:
        """Return the experiment version (always an empty string)."""
        return ""

    def __getitem__(self, idx: int) -> "DummyLogger":
        """Return ``self`` for any index."""
        # enables self.logger[0].experiment.add_image(...)
        return self

    def __getattr__(self, name: str) -> Callable:
        """Allows the DummyLogger to be called with arbitrary methods, to avoid AttributeErrors.

        Args:
            name: the attribute that was not found through normal lookup.

        Returns:
            A callable that accepts any arguments and returns ``None``.

        Raises:
            AttributeError: if ``name`` is a dunder (``__x__``) name.
        """
        # Dunder names are protocol probes (e.g. ``copy.deepcopy`` looks up ``__deepcopy__`` on the instance).
        # Answering them with a no-op would make such protocols silently produce ``None`` instead of falling
        # back to their default behaviour, so report them as missing.
        if name.startswith("__") and name.endswith("__"):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def method(*args: Any, **kwargs: Any) -> None:
            return None

        return method
# TODO: this should have been deprecated
def merge_dicts(  # pragma: no cover
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
    """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
    function.

    Values that are ``None`` are treated as missing. A key whose value is missing in every input dictionary is
    left out of the result. Keys appear in the result in the order in which they are first seen across ``dicts``.

    Args:
        dicts:
            Sequence of dictionaries to be merged. An empty sequence yields an empty dictionary.
        agg_key_funcs:
            Mapping from key name to function. This function will aggregate a
            list of values, obtained from the same key of all dictionaries.
            If some key has no specified aggregation function, the default one
            will be used. Default is: ``None`` (all keys will be aggregated by the
            default function). For a key holding nested dictionaries, the entry is
            itself a mapping that is applied to the nested keys.
        default_func:
            Default function to aggregate keys, which are not presented in the
            `agg_key_funcs` map.

    Returns:
        Dictionary with merged values.

    Examples:
        >>> import pprint
        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
        >>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
        >>> dflt_func = min
        >>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
        {'a': 1.3,
         'b': 2.0,
         'c': 1,
         'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
         'v': 2.3}
        >>> merge_dicts([])
        {}
    """
    agg_key_funcs = agg_key_funcs or {}
    # Ordered union of all keys: unlike a ``set``, this keeps first-seen order (deterministic across runs)
    # and also works when ``dicts`` is empty.
    keys = dict.fromkeys(k for d_in in dicts for k in d_in)
    d_out: Dict = {}
    for k in keys:
        fn = agg_key_funcs.get(k)
        values_to_agg = [v for v in (d_in.get(k) for d_in in dicts) if v is not None]
        if not values_to_agg:
            # the key is present but ``None`` everywhere: nothing to aggregate
            continue
        if isinstance(values_to_agg[0], dict):
            # nested dictionaries are merged recursively; ``fn`` is the nested mapping of functions (or ``None``)
            d_out[k] = merge_dicts(values_to_agg, fn, default_func)
        else:
            d_out[k] = (fn or default_func)(values_to_agg)
    return d_out