# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Abstract base class used to build new loggers."""importargparseimportfunctoolsimportoperatorfromabcimportABC,abstractmethodfromargparseimportNamespacefromfunctoolsimportwrapsfromtypingimportAny,Callable,Dict,Iterable,List,Mapping,Optional,Sequence,UnionfromweakrefimportReferenceTypeimportnumpyasnpimportpytorch_lightningasplfrompytorch_lightning.callbacks.model_checkpointimportModelCheckpointfrompytorch_lightning.utilities.rank_zeroimportrank_zero_deprecation,rank_zero_only
def rank_zero_experiment(fn: Callable) -> Callable:
    """Wrap an ``experiment`` getter so the real experiment is only built on rank 0.

    The returned callable evaluates ``fn(self)`` through ``rank_zero_only``. Whatever that call yields is
    returned if it is truthy; otherwise (e.g. on non-zero ranks, or if ``fn`` itself returned a falsy value)
    a fresh ``DummyExperiment`` is returned instead.

    Args:
        fn: The original experiment getter, taking the logger instance as its only argument.

    Returns:
        A wrapper with ``fn``'s metadata (name, docstring, ``__wrapped__``) copied over via ``functools.wraps``.
    """

    @wraps(fn)
    def experiment(self):
        def get_experiment():
            return fn(self)

        # NOTE(review): DummyExperiment is not defined in the visible part of this module; it is assumed to be
        # provided elsewhere in the file.
        result = rank_zero_only(get_experiment)()
        return result or DummyExperiment()

    return experiment
class LightningLoggerBase(ABC):
    """Base class for experiment loggers.

    Args:
        agg_key_funcs:
            Dictionary which maps a metric name to a function, which will
            aggregate the metric values for the same steps.
        agg_default_func:
            Default function to aggregate metric values. If some metric name
            is not present in the `agg_key_funcs` dictionary, then the
            `agg_default_func` will be used for aggregation.

            .. deprecated:: v1.6
                The parameters `agg_key_funcs` and `agg_default_func` are deprecated
                in v1.6 and will be removed in v1.8.

    Note:
        This base class only stores `agg_key_funcs` and `agg_default_func`; none of its
        methods read them back. In particular,
        :meth:`~LightningLoggerBase.agg_and_log_metrics` forwards straight to
        :meth:`~LightningLoggerBase.log_metrics` without aggregating anything.
    """

    def __init__(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
    ):
        # Legacy aggregation state; not read anywhere in this class.
        self._prev_step: int = -1
        self._metrics_to_agg: List[Dict[str, float]] = []
        if agg_key_funcs:
            # Store a private copy: `update_agg_funcs` calls `.update()` on this attribute, which would
            # otherwise mutate the caller's mapping (or fail if the caller passed a read-only Mapping).
            self._agg_key_funcs = dict(agg_key_funcs)
            rank_zero_deprecation(
                "The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
                " and will be removed in v1.8."
            )
        else:
            self._agg_key_funcs = {}
        if agg_default_func:
            self._agg_default_func = agg_default_func
            rank_zero_deprecation(
                "The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
                " and will be removed in v1.8."
            )
        else:
            self._agg_default_func = np.mean

    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
        """Called after model checkpoint callback saves a new checkpoint.

        The default implementation does nothing.

        Args:
            checkpoint_callback: the model checkpoint callback instance
        """
        pass

    def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Callable[[Sequence[float]], float] = np.mean,
    ):
        """Update aggregation methods.

        .. deprecated:: v1.6
            `update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.

        Args:
            agg_key_funcs:
                Dictionary which maps a metric name to a function, which will
                aggregate the metric values for the same steps. Entries are merged
                into the currently stored mapping.
            agg_default_func:
                Default function to aggregate metric values. If some metric name
                is not present in the `agg_key_funcs` dictionary, then the
                `agg_default_func` will be used for aggregation.
        """
        if agg_key_funcs:
            self._agg_key_funcs.update(agg_key_funcs)
        if agg_default_func:
            self._agg_default_func = agg_default_func
        rank_zero_deprecation("`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.")

    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Records metrics by forwarding them to :meth:`log_metrics`.

        Despite its name, this implementation performs no aggregation: the metrics are passed
        through unchanged, together with ``step``.

        .. deprecated:: v1.6
            This method is deprecated in v1.6 and will be removed in v1.8.
            Please use `LightningLoggerBase.log_metrics` instead.

        Args:
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Step number at which the metrics should be recorded
        """
        self.log_metrics(metrics=metrics, step=step)

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Records metrics.

        This method logs metrics as soon as it receives them; implementations should not
        buffer or aggregate them.

        Args:
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Step number at which the metrics should be recorded
        """
        pass

    @abstractmethod
    def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
        """Record hyperparameters.

        Args:
            params: :class:`~argparse.Namespace` containing the hyperparameters
            args: Optional positional arguments, depends on the specific logger being used
            kwargs: Optional keyword arguments, depends on the specific logger being used
        """

    def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
        """Record model graph.

        The default implementation does nothing.

        Args:
            model: lightning model
            input_array: input passes to `model.forward`
        """
        pass

    def save(self) -> None:
        """Save log data.

        The default implementation does nothing. ``finalize`` and ``close`` call this method,
        so subclasses that need to persist data can override it.
        """

    def finalize(self, status: str) -> None:
        """Do any processing that is necessary to finalize an experiment.

        The default implementation just calls :meth:`save`.

        Args:
            status: Status that the experiment finished with (e.g. success, failed, aborted)
        """
        self.save()

    def close(self) -> None:
        """Do any cleanup that is necessary to close an experiment.

        See deprecation warning below.

        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LightningLoggerBase.finalize` instead.
        """
        rank_zero_deprecation(
            "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LightningLoggerBase.finalize` instead."
        )
        self.save()

    @property
    def save_dir(self) -> Optional[str]:
        """Return the root directory where experiment logs get saved, or `None` if the logger does not save
        data locally."""
        return None

    @property
    def group_separator(self) -> str:
        """Return the default separator used by the logger to group the data into subfolders."""
        return "/"

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the experiment name."""

    @property
    @abstractmethod
    def version(self) -> Union[int, str]:
        """Return the experiment version."""
class LoggerCollection(LightningLoggerBase):
    """The :class:`LoggerCollection` class is used to iterate all logging actions over the given
    `logger_iterable`.

    .. deprecated:: v1.6
        `LoggerCollection` is deprecated in v1.6 and will be removed in v1.8.
        Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers` attribute.

    Args:
        logger_iterable: An iterable collection of loggers
    """

    def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
        super().__init__()
        # NOTE(review): stored as given and iterated on every call; a one-shot iterator (e.g. a generator)
        # would be exhausted after the first forwarded call.
        self._logger_iterable = logger_iterable
        rank_zero_deprecation(
            "`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers"
            " to the Trainer and access the list via the `trainer.loggers` attribute."
        )

    def __getitem__(self, index: int) -> LightningLoggerBase:
        return list(self._logger_iterable)[index]

    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
        """Forward the checkpoint notification to every logger in the collection."""
        for logger in self._logger_iterable:
            logger.after_save_checkpoint(checkpoint_callback)

    def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Callable[[Sequence[float]], float] = np.mean,
    ):
        """Forward the (deprecated) aggregation-function update to every logger in the collection."""
        for logger in self._logger_iterable:
            logger.update_agg_funcs(agg_key_funcs, agg_default_func)

    @property
    def experiment(self) -> List[Any]:
        """Returns a list of experiment objects for all the loggers in the logger collection."""
        return [logger.experiment for logger in self._logger_iterable]

    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Forward to each logger's own ``agg_and_log_metrics`` so per-logger overrides are honored."""
        for logger in self._logger_iterable:
            logger.agg_and_log_metrics(metrics=metrics, step=step)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log ``metrics`` at ``step`` with every logger in the collection."""
        for logger in self._logger_iterable:
            logger.log_metrics(metrics=metrics, step=step)

    def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs) -> None:
        """Log hyperparameters with every logger in the collection, passing extra arguments through."""
        for logger in self._logger_iterable:
            logger.log_hyperparams(params, *args, **kwargs)

    def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
        """Ask every logger in the collection to record the model graph."""
        for logger in self._logger_iterable:
            logger.log_graph(model, input_array)

    def save(self) -> None:
        """Call ``save`` on every logger in the collection."""
        for logger in self._logger_iterable:
            logger.save()

    def finalize(self, status: str) -> None:
        """Finalize every logger in the collection with the given ``status``.

        Each logger's own ``finalize`` is called, so loggers that do more than ``save()`` there
        are finalized properly.
        """
        for logger in self._logger_iterable:
            logger.finalize(status)

    def close(self) -> None:
        """
        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LoggerCollection.finalize` instead.
        """
        rank_zero_deprecation(
            "`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LoggerCollection.finalize` instead."
        )
        for logger in self._logger_iterable:
            logger.close()

    @property
    def save_dir(self) -> Optional[str]:
        """Returns ``None`` as checkpoints should be saved to default / chosen location when using multiple
        loggers."""
        # Checkpoints should be saved to default / chosen location when using multiple loggers
        return None

    @property
    def name(self) -> str:
        """Returns the unique experiment names for all the loggers in the logger collection joined by an
        underscore."""
        return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable))

    @property
    def version(self) -> str:
        """Returns the unique experiment versions for all the loggers in the logger collection joined by an
        underscore."""
        return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable))
class DummyLogger(LightningLoggerBase):
    """Dummy logger for internal use.

    It is useful if we want to disable user's logger for a feature, but still ensure that user code can run.
    Every logging call is accepted and ignored.
    """

    def __init__(self):
        super().__init__()
        self._experiment = DummyExperiment()

    @property
    def experiment(self) -> DummyExperiment:
        """Return the experiment object associated with this logger."""
        return self._experiment

    def log_metrics(self, *args, **kwargs) -> None:
        """Accept and discard metrics (required: ``log_metrics`` is abstract in the base class)."""
        pass

    def log_hyperparams(self, *args, **kwargs) -> None:
        """Accept and discard hyperparameters (required: ``log_hyperparams`` is abstract in the base class)."""
        pass

    def save(self) -> None:
        """No-op; the base class ``finalize`` calls ``self.save()``."""
        pass

    @property
    def name(self) -> str:
        """Return the experiment name."""
        return ""

    @property
    def version(self) -> str:
        """Return the experiment version."""
        return ""

    def __getitem__(self, idx) -> "DummyLogger":
        # enables self.logger[0].experiment.add_image(...)
        return self

    def __iter__(self):
        # if DummyLogger is substituting a logger collection, pretend it is empty
        yield from ()
def merge_dicts(
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
    """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
    function.

    Args:
        dicts:
            Sequence of dictionaries to be merged. An empty sequence yields an empty dictionary.
        agg_key_funcs:
            Mapping from key name to function. This function will aggregate a
            list of values, obtained from the same key of all dictionaries.
            If some key has no specified aggregation function, the default one
            will be used. For keys whose values are themselves dictionaries, the
            entry may be a nested mapping of the same form, which is used for the
            recursive merge. Default is: ``None`` (all keys will be aggregated by the
            default function).
        default_func:
            Default function to aggregate keys, which are not present in the
            `agg_key_funcs` map.

    Returns:
        Dictionary with merged values. ``None`` values are ignored; a key whose values
        are all ``None`` is left out of the result. Keys appear in the order in which
        they are first seen across ``dicts``.

    Examples:
        >>> import pprint
        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
        >>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
        >>> dflt_func = min
        >>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
        {'a': 1.3,
         'b': 2.0,
         'c': 1,
         'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
         'v': 2.3}
    """
    agg_key_funcs = agg_key_funcs or {}
    # Union of all keys in first-appearance order. A set union would make the output order
    # hash-dependent, and `functools.reduce` without an initial value fails on an empty sequence.
    keys = list(dict.fromkeys(k for d in dicts for k in d))
    d_out = {}
    for k in keys:
        fn = agg_key_funcs.get(k)
        values_to_agg = [v for v in (d_in.get(k) for d_in in dicts) if v is not None]
        if not values_to_agg:
            # Every value for this key was None: nothing to aggregate.
            continue

        if isinstance(values_to_agg[0], dict):
            # Nested dictionaries are merged recursively; `fn` is then the nested aggregation mapping.
            d_out[k] = merge_dicts(values_to_agg, fn, default_func)
        else:
            d_out[k] = (fn or default_func)(values_to_agg)

    return d_out
