# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""MLflow Logger-------------"""importloggingimportosimportreimporttempfilefromargparseimportNamespacefrompathlibimportPathfromtimeimporttimefromtypingimportTYPE_CHECKING,Any,Callable,Dict,List,Literal,Mapping,Optional,Unionimportyamlfromlightning_utilities.core.importsimportRequirementCachefromtorchimportTensorfromlightning.fabric.utilities.loggerimport_add_prefix,_convert_params,_flatten_dictfromlightning.pytorch.callbacks.model_checkpointimportModelCheckpointfromlightning.pytorch.loggers.loggerimportLogger,rank_zero_experimentfromlightning.pytorch.loggers.utilitiesimport_scan_checkpointsfromlightning.pytorch.utilities.rank_zeroimportrank_zero_only,rank_zero_warnifTYPE_CHECKING:frommlflow.trackingimportMlflowClientlog=logging.getLogger(__name__)LOCAL_FILE_URI_PREFIX="file:"_MLFLOW_AVAILABLE=RequirementCache("mlflow>=1.0.0","mlflow")
class MLFlowLogger(Logger):
    """Log using `MLflow <https://mlflow.org>`_.

    Install it with pip:

    .. code-block:: bash

        pip install mlflow  # or mlflow-skinny

    .. code-block:: python

        from lightning.pytorch import Trainer
        from lightning.pytorch.loggers import MLFlowLogger

        mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
        trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~lightning.pytorch.core.LightningModule` as follows:

    .. code-block:: python

        from lightning.pytorch import LightningModule


        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # example
                self.logger.experiment.whatever_ml_flow_supports(...)

            def any_lightning_module_function_or_hook(self):
                self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment.
        run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
            If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
            back to `file:<save_dir>`.
        tags: A dictionary tags for the experiment.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlruns` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.
        log_model: Log checkpoints created by
            :class:`~lightning.pytorch.callbacks.model_checkpoint.ModelCheckpoint` as MLFlow artifacts.

            * if ``log_model == 'all'``, checkpoints are logged during training.
            * if ``log_model == True``, checkpoints are logged at the end of training, except when
              :paramref:`~lightning.pytorch.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.

        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.
        run_id: The run identifier of the experiment. If not provided, a new run is started.

    Raises:
        ModuleNotFoundError:
            If required MLFlow package is not installed on the device.

    """

    LOGGER_JOIN_CHAR = "-"

    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_name: Optional[str] = None,
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        log_model: Literal[True, False, "all"] = False,
        prefix: str = "",
        artifact_location: Optional[str] = None,
        run_id: Optional[str] = None,
    ):
        if not _MLFLOW_AVAILABLE:
            raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
        super().__init__()
        # Without an explicit tracking URI, store runs locally under `save_dir`.
        if not tracking_uri:
            tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

        self._experiment_name = experiment_name
        self._experiment_id: Optional[str] = None
        self._tracking_uri = tracking_uri
        self._run_name = run_name
        self._run_id = run_id
        self.tags = tags
        self._log_model = log_model
        # Maps checkpoint path -> modification timestamp of the last logged version.
        self._logged_model_time: Dict[str, float] = {}
        self._checkpoint_callback: Optional[ModelCheckpoint] = None
        self._prefix = prefix
        self._artifact_location = artifact_location
        # Set lazily by `experiment` once the experiment/run have been created or resolved.
        self._initialized = False

        from mlflow.tracking import MlflowClient

        self._mlflow_client = MlflowClient(tracking_uri)

    @property
    @rank_zero_experiment
    def experiment(self) -> "MlflowClient":
        r"""Actual MLflow object. To use MLflow features in your :class:`~lightning.pytorch.core.LightningModule` do
        the following.

        Example::

            self.logger.experiment.some_mlflow_function()

        """
        import mlflow

        if self._initialized:
            return self._mlflow_client

        mlflow.set_tracking_uri(self._tracking_uri)

        # An explicit run id means we resume an existing run; only the experiment id
        # needs to be resolved from the tracking server.
        if self._run_id is not None:
            run = self._mlflow_client.get_run(self._run_id)
            self._experiment_id = run.info.experiment_id
            self._initialized = True
            return self._mlflow_client

        # Resolve (or create) the experiment by name.
        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
            if expt is not None:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.")
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name, artifact_location=self._artifact_location
                )

        # Start a fresh run; a user-supplied run name is stored as the `mlflow.runName` tag.
        if self._run_id is None:
            if self._run_name is not None:
                self.tags = self.tags or {}

                from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME

                if MLFLOW_RUN_NAME in self.tags:
                    log.warning(
                        f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
                    )
                self.tags[MLFLOW_RUN_NAME] = self._run_name

            resolve_tags = _get_resolve_tags()
            run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
            self._run_id = run.info.run_id
        self._initialized = True
        return self._mlflow_client

    @property
    def run_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the run id.

        Returns:
            The run id.

        """
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the experiment id.

        Returns:
            The experiment id.

        """
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:  # type: ignore[override]
        """Record hyperparameters as MLflow run parameters."""
        params = _convert_params(params)
        params = _flatten_dict(params)

        from mlflow.entities import Param

        # Truncate parameter values to 250 characters.
        # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
        params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]

        # Log in chunks of 100 parameters (the maximum allowed by MLflow).
        for idx in range(0, len(params_list), 100):
            self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Record metrics for the current run, sanitizing names to MLflow's allowed character set."""
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

        from mlflow.entities import Metric

        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
        metrics_list: List[Metric] = []

        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            # MLflow metrics must be numeric; drop string-valued entries.
            if isinstance(v, str):
                log.warning(f"Discarding metric with string value {k}={v}.")
                continue

            new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
            if k != new_k:
                rank_zero_warn(
                    "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
                    f" Replacing {k} with {new_k}.",
                    category=RuntimeWarning,
                )
                k = new_k
            metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

        self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)

    @rank_zero_only
    def finalize(self, status: str = "success") -> None:
        """Upload any pending checkpoints and mark the MLflow run as terminated."""
        if not self._initialized:
            return
        # Map Lightning status strings to MLflow run statuses.
        if status == "success":
            status = "FINISHED"
        elif status == "failed":
            status = "FAILED"
        elif status == "finished":
            status = "FINISHED"

        # log checkpoints as artifacts
        if self._checkpoint_callback:
            self._scan_and_log_checkpoints(self._checkpoint_callback)

        if self.experiment.get_run(self.run_id):
            self.experiment.set_terminated(self.run_id, status)

    @property
    def save_dir(self) -> Optional[str]:
        """The root file directory in which MLflow experiments are saved.

        Return:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.

        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # Strip the "file:" scheme by slicing. `str.lstrip(LOCAL_FILE_URI_PREFIX)` would be
            # wrong here: it removes any leading characters from the set {f, i, l, e, :}, so a
            # URI like "file:lightning_logs" would come back as "ghtning_logs".
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
        return None

    @property
    def name(self) -> Optional[str]:
        """Get the experiment id.

        Returns:
            The experiment id.

        """
        return self.experiment_id

    @property
    def version(self) -> Optional[str]:
        """Get the run id.

        Returns:
            The run id.

        """
        return self.run_id

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload checkpoints eagerly (``log_model='all'`` or unbounded top-k), or remember the callback so
        checkpoints can be uploaded once at the end of training in :meth:`finalize`."""
        # log checkpoints as artifacts
        if self._log_model == "all" or (self._log_model is True and checkpoint_callback.save_top_k == -1):
            self._scan_and_log_checkpoints(checkpoint_callback)
        elif self._log_model is True:
            self._checkpoint_callback = checkpoint_callback

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload every checkpoint file not yet logged, with metadata and alias sidecar artifacts."""
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, tag in checkpoints:
            metadata = {
                # Ensure .item() is called to store Tensor contents
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                "Checkpoint": {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                        "_every_n_val_epochs",
                    ]
                    # ensure it does not break if `Checkpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

            # Artifact path on mlflow
            artifact_path = f"model/checkpoints/{Path(p).stem}"

            # Log the checkpoint
            self.experiment.log_artifact(self._run_id, p, artifact_path)

            # Create a temporary directory to log on mlflow
            with tempfile.TemporaryDirectory(prefix="test", suffix="test", dir=os.getcwd()) as tmp_dir:
                # Log the metadata
                with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata:
                    yaml.dump(metadata, tmp_file_metadata, default_flow_style=False)

                # Log the aliases
                with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases:
                    tmp_file_aliases.write(str(aliases))

                # Log the metadata and aliases
                self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path)

            # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
            self._logged_model_time[p] = t
def_get_resolve_tags()->Callable:frommlflow.trackingimportcontext# before v1.1.0ifhasattr(context,"resolve_tags"):frommlflow.tracking.contextimportresolve_tags# since v1.1.0elifhasattr(context,"registry"):frommlflow.tracking.context.registryimportresolve_tagselse:resolve_tags=lambdatags:tagsreturnresolve_tags
To analyze traffic and optimize your experience, we serve cookies on this
site. By clicking or navigating, you agree to allow our usage of cookies.
Read PyTorch Lightning's
Privacy Policy.