# Copyright The Lightning AI team.## Licensed under the Apache License, Version 2.0 (the "License");# you may not use this file except in compliance with the License.# You may obtain a copy of the License at## http://www.apache.org/licenses/LICENSE-2.0## Unless required by applicable law or agreed to in writing, software# distributed under the License is distributed on an "AS IS" BASIS,# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.# See the License for the specific language governing permissions and# limitations under the License."""MLflow Logger-------------"""importloggingimportosimportreimporttempfilefromargparseimportNamespacefrompathlibimportPathfromtimeimporttimefromtypingimportAny,Dict,List,Mapping,Optional,Unionimportyamlfromlightning_utilities.core.importsimportRequirementCachefromtorchimportTensorfromtyping_extensionsimportLiteralfromlightning_fabric.utilities.loggerimport_add_prefix,_convert_params,_flatten_dictfrompytorch_lightning.callbacks.model_checkpointimportModelCheckpointfrompytorch_lightning.loggers.loggerimportLogger,rank_zero_experimentfrompytorch_lightning.utilities.loggerimport_scan_checkpointsfrompytorch_lightning.utilities.rank_zeroimportrank_zero_only,rank_zero_warnlog=logging.getLogger(__name__)LOCAL_FILE_URI_PREFIX="file:"_MLFLOW_AVAILABLE=RequirementCache("mlflow>=1.0.0")if_MLFLOW_AVAILABLE:frommlflow.entitiesimportMetric,Paramfrommlflow.trackingimportcontext,MlflowClientfrommlflow.utils.mlflow_tagsimportMLFLOW_RUN_NAMEelse:MlflowClient,context=None,NoneMetric,Param=None,NoneMLFLOW_RUN_NAME="mlflow.runName"# before v1.1.0ifhasattr(context,"resolve_tags"):frommlflow.tracking.contextimportresolve_tags# since v1.1.0elifhasattr(context,"registry"):frommlflow.tracking.context.registryimportresolve_tagselse:
def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
    """Fallback used when the installed mlflow exposes no ``resolve_tags``.

    Args:
        tags: A dictionary of tags to override. If specified, tags passed in
            this argument will override those inferred from the context.

    Returns:
        A dictionary of resolved tags.

    Note:
        See ``mlflow.tracking.context.registry`` for more details.
    """
    # No context is available to merge from, so the caller's tags are
    # returned untouched (including the ``None`` default).
    return tags
class MLFlowLogger(Logger):
    """Log using `MLflow <https://mlflow.org>`_.

    Install it with pip:

    .. code-block:: bash

        pip install mlflow

    .. code-block:: python

        from pytorch_lightning import Trainer
        from pytorch_lightning.loggers import MLFlowLogger

        mlf_logger = MLFlowLogger(experiment_name="lightning_logs", tracking_uri="file:./ml-runs")
        trainer = Trainer(logger=mlf_logger)

    Use the logger anywhere in your :class:`~pytorch_lightning.core.module.LightningModule` as follows:

    .. code-block:: python

        from pytorch_lightning import LightningModule


        class LitModel(LightningModule):
            def training_step(self, batch, batch_idx):
                # example
                self.logger.experiment.whatever_ml_flow_supports(...)

            def any_lightning_module_function_or_hook(self):
                self.logger.experiment.whatever_ml_flow_supports(...)

    Args:
        experiment_name: The name of the experiment.
        run_name: Name of the new run. The `run_name` is internally stored as a ``mlflow.runName`` tag.
            If the ``mlflow.runName`` tag has already been set in `tags`, the value is overridden by the `run_name`.
        tracking_uri: Address of local or remote tracking server.
            If not provided, defaults to `MLFLOW_TRACKING_URI` environment variable if set, otherwise it falls
            back to `file:<save_dir>`.
        tags: A dictionary tags for the experiment.
        save_dir: A path to a local directory where the MLflow runs get saved.
            Defaults to `./mlflow` if `tracking_uri` is not provided.
            Has no effect if `tracking_uri` is provided.
        log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
            as MLFlow artifacts.

            * if ``log_model == 'all'``, checkpoints are logged during training.
            * if ``log_model == True``, checkpoints are logged at the end of training, except when
              :paramref:`~pytorch_lightning.callbacks.Checkpoint.save_top_k` ``== -1``
              which also logs every checkpoint during training.
            * if ``log_model == False`` (default), no checkpoint is logged.

        prefix: A string to put at the beginning of metric keys.
        artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
            default.
        run_id: The run identifier of the experiment. If not provided, a new run is started.

    Raises:
        ModuleNotFoundError:
            If required MLFlow package is not installed on the device.
    """

    LOGGER_JOIN_CHAR = "-"

    def __init__(
        self,
        experiment_name: str = "lightning_logs",
        run_name: Optional[str] = None,
        # NOTE: this default is evaluated once at class-definition time, so the
        # environment variable is read when the module is imported.
        tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
        tags: Optional[Dict[str, Any]] = None,
        save_dir: Optional[str] = "./mlruns",
        log_model: Literal[True, False, "all"] = False,
        prefix: str = "",
        artifact_location: Optional[str] = None,
        run_id: Optional[str] = None,
    ):
        if not _MLFLOW_AVAILABLE:
            raise ModuleNotFoundError(str(_MLFLOW_AVAILABLE))
        super().__init__()
        # Without an explicit tracking URI, store runs locally under `save_dir`.
        if not tracking_uri:
            tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

        self._experiment_name = experiment_name
        self._experiment_id: Optional[str] = None
        self._tracking_uri = tracking_uri
        self._run_name = run_name
        self._run_id = run_id
        self.tags = tags
        self._log_model = log_model
        # Maps checkpoint path -> modification time of the last upload, so the
        # same file is not logged twice unless it changed.
        self._logged_model_time: Dict[str, float] = {}
        self._checkpoint_callback: Optional[ModelCheckpoint] = None
        self._prefix = prefix
        self._artifact_location = artifact_location
        # Set to True once the experiment/run have been created lazily.
        self._initialized = False

        self._mlflow_client = MlflowClient(tracking_uri)

    @property
    @rank_zero_experiment
    def experiment(self) -> MlflowClient:
        r"""
        Actual MLflow object. To use MLflow features in your
        :class:`~pytorch_lightning.core.module.LightningModule` do the following.

        Example::

            self.logger.experiment.some_mlflow_function()

        """
        if self._initialized:
            return self._mlflow_client

        # Resuming an existing run: derive the experiment id from the run.
        if self._run_id is not None:
            run = self._mlflow_client.get_run(self._run_id)
            self._experiment_id = run.info.experiment_id
            self._initialized = True
            return self._mlflow_client

        # Look up (or create) the experiment by name.
        if self._experiment_id is None:
            expt = self._mlflow_client.get_experiment_by_name(self._experiment_name)
            if expt is not None:
                self._experiment_id = expt.experiment_id
            else:
                log.warning(f"Experiment with name {self._experiment_name} not found. Creating it.")
                self._experiment_id = self._mlflow_client.create_experiment(
                    name=self._experiment_name, artifact_location=self._artifact_location
                )

        # Start a fresh run; the run name is stored as the MLFLOW_RUN_NAME tag.
        if self._run_id is None:
            if self._run_name is not None:
                self.tags = self.tags or {}
                if MLFLOW_RUN_NAME in self.tags:
                    log.warning(
                        f"The tag {MLFLOW_RUN_NAME} is found in tags. The value will be overridden by {self._run_name}."
                    )
                self.tags[MLFLOW_RUN_NAME] = self._run_name
            run = self._mlflow_client.create_run(experiment_id=self._experiment_id, tags=resolve_tags(self.tags))
            self._run_id = run.info.run_id
        self._initialized = True
        return self._mlflow_client

    @property
    def run_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the run id.

        Returns:
            The run id.
        """
        # Accessing `experiment` triggers the lazy creation above.
        _ = self.experiment
        return self._run_id

    @property
    def experiment_id(self) -> Optional[str]:
        """Create the experiment if it does not exist to get the experiment id.

        Returns:
            The experiment id.
        """
        _ = self.experiment
        return self._experiment_id

    @rank_zero_only
    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
        """Record hyperparameters on the active MLflow run."""
        params = _convert_params(params)
        params = _flatten_dict(params)

        # Truncate parameter values to 250 characters.
        # TODO: MLflow 1.28 allows up to 500 characters: https://github.com/mlflow/mlflow/releases/tag/v1.28.0
        params_list = [Param(key=k, value=str(v)[:250]) for k, v in params.items()]

        # Log in chunks of 100 parameters (the maximum allowed by MLflow).
        for idx in range(0, len(params_list), 100):
            self.experiment.log_batch(run_id=self.run_id, params=params_list[idx : idx + 100])

    @rank_zero_only
    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Record metrics on the active MLflow run, sanitizing metric names."""
        assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"

        metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
        metrics_list: List[Metric] = []

        timestamp_ms = int(time() * 1000)
        for k, v in metrics.items():
            if isinstance(v, str):
                log.warning(f"Discarding metric with string value {k}={v}.")
                continue

            # MLflow rejects metric names with characters outside this set.
            new_k = re.sub("[^a-zA-Z0-9_/. -]+", "", k)
            if k != new_k:
                rank_zero_warn(
                    "MLFlow only allows '_', '/', '.' and ' ' special characters in metric name."
                    f" Replacing {k} with {new_k}.",
                    category=RuntimeWarning,
                )
                k = new_k
            metrics_list.append(Metric(key=k, value=v, timestamp=timestamp_ms, step=step or 0))

        self.experiment.log_batch(run_id=self.run_id, metrics=metrics_list)

    @rank_zero_only
    def finalize(self, status: str = "success") -> None:
        """Upload any pending checkpoints and mark the MLflow run terminated."""
        if not self._initialized:
            return
        # Map Lightning status strings onto MLflow run statuses.
        if status == "success":
            status = "FINISHED"
        elif status == "failed":
            status = "FAILED"
        elif status == "finished":
            status = "FINISHED"

        # log checkpoints as artifacts
        if self._checkpoint_callback:
            self._scan_and_log_checkpoints(self._checkpoint_callback)

        if self.experiment.get_run(self.run_id):
            self.experiment.set_terminated(self.run_id, status)

    @property
    def save_dir(self) -> Optional[str]:
        """The root file directory in which MLflow experiments are saved.

        Return:
            Local path to the root experiment directory if the tracking uri is local.
            Otherwise returns `None`.
        """
        if self._tracking_uri.startswith(LOCAL_FILE_URI_PREFIX):
            # NOTE: `str.lstrip(LOCAL_FILE_URI_PREFIX)` would be wrong here - it
            # strips a *character set* ('f', 'i', 'l', 'e', ':'), not a prefix,
            # mangling paths such as "file:files/runs". Slice the prefix off instead.
            return self._tracking_uri[len(LOCAL_FILE_URI_PREFIX) :]
        return None

    @property
    def name(self) -> Optional[str]:
        """Get the experiment id.

        Returns:
            The experiment id.
        """
        return self.experiment_id

    @property
    def version(self) -> Optional[str]:
        """Get the run id.

        Returns:
            The run id.
        """
        return self.run_id

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Log checkpoints as artifacts, either eagerly or deferred to `finalize`."""
        # Parentheses added for clarity; `and` already binds tighter than `or`.
        if self._log_model == "all" or (self._log_model is True and checkpoint_callback.save_top_k == -1):
            self._scan_and_log_checkpoints(checkpoint_callback)
        elif self._log_model is True:
            # Defer the upload until training finishes (see `finalize`).
            self._checkpoint_callback = checkpoint_callback

    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
        """Upload new/updated checkpoints from `checkpoint_callback` as MLflow artifacts."""
        # get checkpoints to be saved with associated score
        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

        # log iteratively all new checkpoints
        for t, p, s, tag in checkpoints:
            metadata = {
                # Ensure .item() is called to store Tensor contents
                "score": s.item() if isinstance(s, Tensor) else s,
                "original_filename": Path(p).name,
                "Checkpoint": {
                    k: getattr(checkpoint_callback, k)
                    for k in [
                        "monitor",
                        "mode",
                        "save_last",
                        "save_top_k",
                        "save_weights_only",
                        "_every_n_train_steps",
                        "_every_n_val_epochs",
                    ]
                    # ensure it does not break if `Checkpoint` args change
                    if hasattr(checkpoint_callback, k)
                },
            }
            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]

            # Artifact path on mlflow
            artifact_path = f"model/checkpoints/{Path(p).stem}"

            # Log the checkpoint
            self.experiment.log_artifact(self._run_id, p, artifact_path)

            # Create a temporary directory to log on mlflow
            with tempfile.TemporaryDirectory(prefix="test", suffix="test", dir=os.getcwd()) as tmp_dir:
                # Log the metadata
                with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata:
                    yaml.dump(metadata, tmp_file_metadata, default_flow_style=False)

                # Log the aliases
                with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases:
                    tmp_file_aliases.write(str(aliases))

                # Log the metadata and aliases
                self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path)

            # remember logged models - timestamp needed in case filename didn't change (lastkckpt or custom name)
            self._logged_model_time[p] = t
To analyze traffic and optimize your experience, we serve cookies on this site. By clicking or navigating, you agree to allow our usage of cookies. Read PyTorch Lightning's Privacy Policy.