# Copyright The PyTorch Lightning team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""Abstract base class used to build new loggers."""importargparseimportfunctoolsimportoperatorfromabcimportABC,abstractmethodfromargparseimportNamespacefromfunctoolsimportwrapsfromtypingimportAny,Callable,Dict,Iterable,List,Mapping,Optional,Sequence,UnionfromweakrefimportReferenceTypeimportnumpyasnpimportpytorch_lightningasplfrompytorch_lightning.callbacks.model_checkpointimportModelCheckpointfrompytorch_lightning.utilities.rank_zeroimportrank_zero_deprecation,rank_zero_only
def rank_zero_experiment(fn: Callable) -> Callable:
    """Decorate an ``experiment`` getter so only rank 0 receives the real experiment object.

    The wrapped getter is invoked through ``rank_zero_only``. Whenever that invocation
    yields a falsy value (presumably ``None`` on non-zero ranks, where ``rank_zero_only``
    skips the call — confirm against its implementation), a ``DummyExperiment`` is
    returned in its place.

    Args:
        fn: The original ``experiment`` getter, taking the logger instance as its only argument.

    Returns:
        A replacement getter with the same metadata as ``fn``.
    """

    @wraps(fn)
    def _wrapper(self):
        @rank_zero_only
        def _fetch():
            return fn(self)

        real_experiment = _fetch()
        # Any falsy result (not just None) falls back to the dummy, mirroring `x or y`.
        return real_experiment if real_experiment else DummyExperiment()

    return _wrapper
class LightningLoggerBase(ABC):
    """Base class for experiment loggers.

    Args:
        agg_key_funcs:
            Dictionary which maps a metric name to a function, which will
            aggregate the metric values for the same steps.
        agg_default_func:
            Default function to aggregate metric values. If some metric name
            is not presented in the `agg_key_funcs` dictionary, then the
            `agg_default_func` will be used for aggregation.

            .. deprecated:: v1.6
                The parameters `agg_key_funcs` and `agg_default_func` are deprecated
                in v1.6 and will be removed in v1.8.

    Note:
        The `agg_key_funcs` and `agg_default_func` arguments are only stored by this class.
        The deprecated :meth:`~LightningLoggerBase.agg_and_log_metrics` forwards straight to
        :meth:`~LightningLoggerBase.log_metrics`, so no code in this class applies them.
    """

    def __init__(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
    ):
        self._prev_step: int = -1
        self._metrics_to_agg: List[Dict[str, float]] = []
        self._agg_key_funcs: Dict[str, Callable[[Sequence[float]], float]]
        if agg_key_funcs:
            # Take a private copy: `update_agg_funcs` updates this dict in place, which would
            # otherwise mutate the caller's mapping (or fail on a read-only Mapping).
            self._agg_key_funcs = dict(agg_key_funcs)
            rank_zero_deprecation(
                "The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
                " and will be removed in v1.8."
            )
        else:
            self._agg_key_funcs = {}
        if agg_default_func:
            self._agg_default_func = agg_default_func
            rank_zero_deprecation(
                "The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
                " and will be removed in v1.8."
            )
        else:
            self._agg_default_func = np.mean

    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
        """Called after model checkpoint callback saves a new checkpoint.

        The base implementation does nothing.

        Args:
            checkpoint_callback: the model checkpoint callback instance
        """
        pass

    def update_agg_funcs(
        self,
        agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
        agg_default_func: Callable[[Sequence[float]], float] = np.mean,
    ):
        """Update aggregation methods.

        .. deprecated:: v1.6
            `update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.

        Args:
            agg_key_funcs:
                Dictionary which maps a metric name to a function, which will
                aggregate the metric values for the same steps. Entries are merged
                into the existing mapping.
            agg_default_func:
                Default function to aggregate metric values. If some metric name
                is not presented in the `agg_key_funcs` dictionary, then the
                `agg_default_func` will be used for aggregation. Note that omitting
                this argument resets the default function to ``np.mean``.
        """
        if agg_key_funcs:
            self._agg_key_funcs.update(agg_key_funcs)
        if agg_default_func:
            self._agg_default_func = agg_default_func
        rank_zero_deprecation("`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.")

    def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Records metrics (formerly: aggregated them before recording).

        This method no longer aggregates anything: it forwards ``metrics`` and ``step``
        unchanged to :meth:`log_metrics`.

        .. deprecated:: v1.6
            This method is deprecated in v1.6 and will be removed in v1.8.
            Please use `LightningLoggerBase.log_metrics` instead.

        Args:
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Step number at which the metrics should be recorded
        """
        self.log_metrics(metrics=metrics, step=step)

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
        """Records metrics.

        This method logs metrics as soon as it receives them. The deprecated
        :meth:`~pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics`
        method forwards directly to this one.

        Args:
            metrics: Dictionary with metric names as keys and measured quantities as values
            step: Step number at which the metrics should be recorded
        """
        pass

    @abstractmethod
    def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
        """Record hyperparameters.

        Args:
            params: :class:`~argparse.Namespace` containing the hyperparameters
            args: Optional positional arguments, depends on the specific logger being used
            kwargs: Optional keyword arguments, depends on the specific logger being used
        """

    def log_graph(self, model: "pl.LightningModule", input_array: Optional[Any] = None) -> None:
        """Record model graph.

        The base implementation does nothing.

        Args:
            model: lightning model
            input_array: input passes to `model.forward`
        """
        pass

    def save(self) -> None:
        """Save log data.

        Called by :meth:`finalize` (and the deprecated :meth:`close`). The base implementation
        does nothing; loggers that buffer data should override it.
        """

    def finalize(self, status: str) -> None:
        """Do any processing that is necessary to finalize an experiment.

        The base implementation only calls :meth:`save`.

        Args:
            status: Status that the experiment finished with (e.g. success, failed, aborted)
        """
        self.save()

    def close(self) -> None:
        """Do any cleanup that is necessary to close an experiment.

        See deprecation warning below.

        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LightningLoggerBase.finalize` instead.
        """
        rank_zero_deprecation(
            "`LightningLoggerBase.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LightningLoggerBase.finalize` instead."
        )
        self.save()

    @property
    def save_dir(self) -> Optional[str]:
        """Return the root directory where experiment logs get saved, or `None` if the logger does not
        save data locally."""
        return None

    @property
    def group_separator(self) -> str:
        """Return the default separator used by the logger to group the data into subfolders."""
        return "/"

    @property
    @abstractmethod
    def name(self) -> str:
        """Return the experiment name."""

    @property
    @abstractmethod
    def version(self) -> Union[int, str]:
        """Return the experiment version."""
class LoggerCollection(LightningLoggerBase):
    """The :class:`LoggerCollection` class is used to iterate all logging actions over the given
    `logger_iterable`.

    .. deprecated:: v1.6
        `LoggerCollection` is deprecated in v1.6 and will be removed in v1.8.
        Directly pass a list of loggers to the Trainer and access the list via the `trainer.loggers`
        attribute.

    Args:
        logger_iterable: An iterable collection of loggers
    """

    def __init__(self, logger_iterable: Iterable[LightningLoggerBase]):
        super().__init__()
        # Materialize once: every fan-out method iterates the loggers and `__getitem__` indexes
        # them, neither of which works repeatedly on a one-shot iterator such as a generator.
        self._logger_iterable = list(logger_iterable)
        rank_zero_deprecation(
            "`LoggerCollection` is deprecated in v1.6 and will be removed in v1.8. Directly pass a list of loggers"
            " to the Trainer and access the list via the `trainer.loggers` attribute."
        )

    def __getitem__(self, index: int) -> LightningLoggerBase:
        """Return the logger at ``index`` (slices are supported too)."""
        return self._logger_iterable[index]

    def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
        """Forward the checkpoint notification to every logger in the collection."""
        for logger in self._logger_iterable:
            logger.after_save_checkpoint(checkpoint_callback)

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Record ``metrics`` at ``step`` with every logger in the collection."""
        for logger in self._logger_iterable:
            logger.log_metrics(metrics=metrics, step=step)

    def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs) -> None:
        """Record hyperparameters with every logger, passing extra arguments through unchanged."""
        for logger in self._logger_iterable:
            logger.log_hyperparams(params, *args, **kwargs)

    def log_graph(self, model: "pl.LightningModule", input_array: Optional[Any] = None) -> None:
        """Record the model graph with every logger in the collection."""
        for logger in self._logger_iterable:
            logger.log_graph(model, input_array)

    def save(self) -> None:
        """Ask every logger in the collection to save its data."""
        for logger in self._logger_iterable:
            logger.save()

    def finalize(self, status: str) -> None:
        """Finalize every logger in the collection with the given ``status``."""
        for logger in self._logger_iterable:
            logger.finalize(status)

    @property
    def experiment(self) -> List[Any]:
        """Returns a list of experiment objects for all the loggers in the logger collection."""
        return [logger.experiment for logger in self._logger_iterable]

    def close(self) -> None:
        """
        .. deprecated:: v1.5
            This method is deprecated in v1.5 and will be removed in v1.7.
            Please use `LoggerCollection.finalize` instead.
        """
        rank_zero_deprecation(
            "`LoggerCollection.close` method is deprecated in v1.5 and will be removed in v1.7."
            " Please use `LoggerCollection.finalize` instead."
        )
        for logger in self._logger_iterable:
            logger.close()

    @property
    def save_dir(self) -> Optional[str]:
        """Returns ``None`` as checkpoints should be saved to default / chosen location when using multiple
        loggers."""
        # Checkpoints should be saved to default / chosen location when using multiple loggers
        return None

    @property
    def name(self) -> str:
        """Returns the unique experiment names for all the loggers in the logger collection joined by an
        underscore."""
        # dict.fromkeys de-duplicates while keeping first-seen order.
        return "_".join(dict.fromkeys(str(logger.name) for logger in self._logger_iterable))

    @property
    def version(self) -> str:
        """Returns the unique experiment versions for all the loggers in the logger collection joined by an
        underscore."""
        return "_".join(dict.fromkeys(str(logger.version) for logger in self._logger_iterable))
class DummyLogger(LightningLoggerBase):
    """Dummy logger for internal use.

    It is useful if we want to disable user's logger for a feature, but still ensure that user code can run.
    Every logging call is accepted and discarded.
    """

    def __init__(self):
        super().__init__()
        self._experiment = DummyExperiment()

    @property
    def experiment(self) -> DummyExperiment:
        """Return the experiment object associated with this logger."""
        return self._experiment

    # The base class declares `log_metrics` and `log_hyperparams` abstract; without these
    # no-op overrides `DummyLogger()` cannot be instantiated.
    def log_metrics(self, *args, **kwargs) -> None:
        """Accept and discard metrics."""

    def log_hyperparams(self, *args, **kwargs) -> None:
        """Accept and discard hyperparameters."""

    @property
    def name(self) -> str:
        """Return the experiment name."""
        return ""

    @property
    def version(self) -> str:
        """Return the experiment version."""
        return ""

    def __getitem__(self, idx) -> "DummyLogger":
        # enables self.logger[0].experiment.add_image(...)
        return self

    def __iter__(self):
        # if DummyLogger is substituting a logger collection, pretend it is empty
        yield from ()
def merge_dicts(
    dicts: Sequence[Mapping],
    agg_key_funcs: Optional[Mapping] = None,
    default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
    """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
    function.

    Args:
        dicts:
            Sequence of dictionaries to be merged.
        agg_key_funcs:
            Mapping from key name to function. This function will aggregate a
            list of values, obtained from the same key of all dictionaries.
            If some key has no specified aggregation function, the default
            one will be used. Default is: ``None`` (all keys will be aggregated by the
            default function). For a key whose values are nested dictionaries, the entry
            may be another mapping (applied to the nested keys) or a single callable
            (used as the default function for every nested key).
        default_func:
            Default function to aggregate keys, which are not presented in the
            `agg_key_funcs` map.

    Returns:
        Dictionary with merged values. Keys appear in first-seen order across ``dicts``.
        ``None`` values are treated as missing, so a key whose values are all ``None``
        is omitted. An empty ``dicts`` yields ``{}``.

    Examples:
        >>> import pprint
        >>> d1 = {'a': 1.7, 'b': 2.0, 'c': 1, 'd': {'d1': 1, 'd3': 3}}
        >>> d2 = {'a': 1.1, 'b': 2.2, 'v': 1, 'd': {'d1': 2, 'd2': 3}}
        >>> d3 = {'a': 1.1, 'v': 2.3, 'd': {'d3': 3, 'd4': {'d5': 1}}}
        >>> dflt_func = min
        >>> agg_funcs = {'a': np.mean, 'v': max, 'd': {'d1': sum}}
        >>> pprint.pprint(merge_dicts([d1, d2, d3], agg_funcs, dflt_func))
        {'a': 1.3,
         'b': 2.0,
         'c': 1,
         'd': {'d1': 3, 'd2': 3, 'd3': 3, 'd4': {'d5': 1}},
         'v': 2.3}
    """
    agg_key_funcs = agg_key_funcs or {}
    # Union of all keys in first-seen order: deterministic (unlike iterating a set of
    # strings) and well-defined when `dicts` is empty (unlike reduce without initial).
    keys = dict.fromkeys(key for d in dicts for key in d)
    d_out = {}
    for k in keys:
        fn = agg_key_funcs.get(k)
        values_to_agg = [v for v in (d_in.get(k) for d_in in dicts) if v is not None]
        if not values_to_agg:
            # Every occurrence was None: nothing to aggregate, so leave the key out.
            continue

        if isinstance(values_to_agg[0], dict):
            if callable(fn):
                # A single function given for a nested dict applies to all its keys.
                d_out[k] = merge_dicts(values_to_agg, None, fn)
            else:
                d_out[k] = merge_dicts(values_to_agg, fn, default_func)
        else:
            d_out[k] = (fn or default_func)(values_to_agg)

    return d_out
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.