# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
MLflow Logger
-------------
"""
import logging
import os
import re
from argparse import Namespace
from time import time
from typing import Any, Dict, Mapping, Optional, Union

from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
from pytorch_lightning.utilities.imports import _module_available
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)
LOCAL_FILE_URI_PREFIX = "file:"
_MLFLOW_AVAILABLE = _module_available("mlflow")
try:
    import mlflow
    from mlflow.tracking import context, MlflowClient
    from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
# Catch ``ImportError`` (the base class of ``ModuleNotFoundError``) rather than only the missing-module
# case: a present-but-broken MLflow installation (as reported with some Conda environments) raises a plain
# ``ImportError``, and letting it propagate would break importing this module for every user, even those
# who never use MLflow. In that case MLflow is treated as unavailable.
except ImportError:
    _MLFLOW_AVAILABLE = False
    mlflow, MlflowClient, context = None, None, None
    # Fallback value for the tag key that stores the run name.
    MLFLOW_RUN_NAME = "mlflow.runName"

if hasattr(context, "resolve_tags"):
    # before v1.1.0
    from mlflow.tracking.context import resolve_tags
elif hasattr(context, "registry"):
    # since v1.1.0
    from mlflow.tracking.context.registry import resolve_tags
else:

    def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
        """Fallback used when MLflow's context-based tag resolution is unavailable; returns ``tags`` unchanged.

        This branch is taken when ``context`` exposes neither ``resolve_tags`` nor ``registry``,
        including when MLflow could not be imported (``context`` is ``None``).

        Args:
            tags: A dictionary of tags to override. If specified, tags passed in this argument will override those
                inferred from the context.

        Returns:
            A dictionary of resolved tags (here: the input, as-is).

        Note:
            See ``mlflow.tracking.context.registry`` for more details.
        """
        return tags
class MLFlowLogger(Logger):
    """Log using `MLflow <https://mlflow.org>`_.

    Install it with pip:

    .. code-block:: bash

        pip install mlflow

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.loggers import MLFlowLogger

        mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
        trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:

    .. code-block:: python

        from pytorch_lightning import LightningModule


        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # example
                self.logger.experiment.whatever_ml_flow_supports(...)

            def any_lightning_module_function_or_hook(self):
                self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment.
        run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
            If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to the `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
            back to `file:<save_dir>`. Note that the environment variable is read once, when this module is imported.
        tags: A dictionary of tags for the run created by this logger. The dictionary passed in is not modified.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlruns` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.
        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.
        run_id: The run identifier of the experiment. If not provided, a new run is started.

    Raises:
        ModuleNotFoundError:
            If required MLFlow package is not installed on the device.
    """

    LOGGER_JOIN_CHAR = "-"

    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_name: Optional[str] = None,
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        prefix: str = "",
        artifact_location: Optional[str] = None,
        run_id: Optional[str] = None,
    ):
        if mlflow is None:
            raise ModuleNotFoundError(
                "You want to use `mlflow` logger which is not installed yet, install it with `pip install mlflow`."
            )
        super().__init__()
        if not tracking_uri:
            # No explicit URI and no env var at import time: log to the local directory.
            tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

        self._experiment_name = experiment_name
        self._experiment_id: Optional[str] = None
        self._tracking_uri = tracking_uri
        self._run_name = run_name
        self._run_id = run_id
        self.tags = tags
        self._prefix = prefix
        self._artifact_location = artifact_location

        # Set once the experiment/run have been resolved by ``experiment``.
        self._initialized = False
        self._mlflow_client = MlflowClient(tracking_uri)

    @property  # type: ignore[misc]
    @rank_zero_experiment
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use MLflow features in your
        :class:`~pytorch_lightning.core.module.LightningModule` do the following.

        On first access this resolves (creating if necessary) the experiment and the run;
        later accesses return the cached client.

        Example::

            self.logger.experiment.some_mlflow_function()

        """
        if self._initialized:
            return self._mlflow_client

        if self._run_id is not None:
            # Resuming a given run: look up the experiment it belongs to.
            run = self._mlflow_client.get_run(self._run_id)
            self._experiment_id = run.info.experiment_id
            self._initialized = True
            return self._mlflow_client

        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
            if expt is not None:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.")
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name, artifact_location=self._artifact_location
                )

        # No run id was given (handled above), so start a new run.
        if self._run_name is not None:
            # Work on a copy so the dictionary supplied by the caller is never mutated.
            self.tags = dict(self.tags or {})
            if MLFLOW_RUN_NAME in self.tags:
                log.warning(
                    f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
                )
            self.tags[MLFLOW_RUN_NAME] = self._run_name
        run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
        self._run_id = run.info.run_id
        self._initialized = True
        return self._mlflow_client

    @property
    def run_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the run id.

        Returns:
            The run id.
        """
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the experiment id.

        Returns:
            The experiment id.
        """
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Log hyperparameters as MLflow run params, skipping values longer than 250 characters.

        Args:
            params: Hyperparameters as a dict or ``Namespace``; nested dicts are flattened first.
        """
        params = _convert_params(params)
        params = _flatten_dict(params)
        for k, v in params.items():
            if len(str(v)) > 250:
                rank_zero_warn(
                    f"Mlflow only allows parameters with up to 250 characters. Discard {k}={v}", category=RuntimeWarning
                )
                continue

            self.experiment.log_param(self.run_id, k, v)

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Log metrics to the MLflow run.

        String values are discarded, and characters MLflow does not accept in metric names are removed.
        All metrics of one call share the same timestamp.

        Args:
            metrics: Mapping of metric name to value; names are prefixed with ``prefix`` if one was set.
            step: Optional step index passed through to MLflow.
        """
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)

        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f"Discarding metric with string value {k}={v}.")
                continue

            # Keep alphanumerics plus '_', '/', '.', ' ' and '-'; drop everything else.
            new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
            if k != new_k:
                rank_zero_warn(
                    "MLFlow only allows '_', '/', '.', '-' and ' ' special characters in metric name."
                    f" Replacing {k} with {new_k}.",
                    category=RuntimeWarning,
                )
                k = new_k

            self.experiment.log_metric(self.run_id, k, v, timestamp_ms, step)

    @property
    def save_dir(self) -> Optional[str]:
        """The root file directory in which MLflow experiments are saved.

        Returns:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.
        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # Slice off exactly the prefix. ``str.lstrip`` would instead strip any leading run of the
            # characters "f", "i", "l", "e", ":" and mangle paths such as "file:logs" into "ogs".
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
        return None

    @property
    def name(self) -> Optional[str]:
        """Get the experiment id.

        Returns:
            The experiment id.
        """
        return self.experiment_id

    @property
    def version(self) -> Optional[str]:
        """Get the run id.

        Returns:
            The run id.
        """
        return self.run_id
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.