MLflow model can't be registered

Hi,

I recently started implementing MLOps in our projects, and we chose MLflow because we thought it integrated seamlessly with Lightning.

The problem is registering the model. According to the MLflow docs, you should be able to register a model with the “Register Model” button in the UI: MLflow Model Registry (MLflow 1.3.0 documentation).

A model saved with MLflow autolog or with Lightning's ModelCheckpoint doesn't get that button. It is saved in a different format, and the “Register Model” button is missing on the artifact.

I think MLflow expects the saved model in a different format before it can register it.
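To check that hypothesis, it helps to know what an MLflow Model looks like on disk: a directory whose metadata lives in a file named `MLmodel`, next to the serialized model. A Lightning `.ckpt` file uploaded as an ordinary artifact has no such file, which would be consistent with the missing button. The helper below is a hypothetical, standard-library-only sketch (`is_mlflow_model_dir` is not an MLflow function) that checks a local artifact directory for that file; it does not call MLflow itself.

```python
from pathlib import Path


def is_mlflow_model_dir(path: str) -> bool:
    """Return True if `path` looks like an MLflow Model directory.

    Hypothetical helper: an MLflow Model keeps its metadata in a file named
    exactly "MLmodel"; a bare Lightning checkpoint directory has no such file.
    """
    p = Path(path)
    return p.is_dir() and (p / "MLmodel").is_file()
```

It could be run on a directory downloaded from the run's artifacts to see whether a given artifact is in the registrable format.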

One workaround is to call mlflow.pytorch.log_model() in on_validation_epoch_end (in either a callback or the LightningModule). However, this means users have to reimplement everything ModelCheckpoint already does (tracking the monitored metric, etc.). For example:

    import mlflow

    def on_validation_epoch_end(self) -> None:
        # Reopen the run MLFlowLogger writes to (public run_id property), then log an MLflow Model into it.
        with mlflow.start_run(run_id=self.logger.run_id):
            mlflow.pytorch.log_model(self, "my_model")

TL;DR:
How can I save the model with ModelCheckpoint in a format that MLflow can register?

Thanks.

To log the model to MLflow, you can use Lightning's MLFlowLogger:

    from lightning.pytorch.loggers import MLFlowLogger

Then use it like this:

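    import mlflow
    import torch
    from lightning.pytorch import Trainer
    from lightning.pytorch.callbacks import EarlyStopping

    # `config`, `checkpoint_callback` (a ModelCheckpoint), `model` and `datamodule`
    # are assumed to be defined elsewhere (see the linked example).
    with mlflow.start_run(
        run_name=config.mlflow_run_name,
        description=config.mlflow_description,
    ) as run:
        # Point the logger at the run opened above instead of letting it create a new one.
        mlf_logger = MLFlowLogger(
            experiment_name=mlflow.get_experiment(run.info.experiment_id).name,
            tracking_uri=mlflow.get_tracking_uri(),
            log_model=True,  # upload ModelCheckpoint's checkpoint files as run artifacts
            run_id=run.info.run_id,
        )

        trainer = Trainer(
            callbacks=[
                EarlyStopping(
                    monitor="Val_F1_Score",
                    min_delta=config.min_delta,
                    patience=config.patience,
                    verbose=True,
                    mode="max",
                ),
                checkpoint_callback,
            ],
            default_root_dir=config.model_checkpoint_dir,
            fast_dev_run=bool(config.debug_mode_sample),
            max_epochs=config.max_epochs,
            max_time=config.max_time,
            precision="bf16-mixed" if torch.cuda.is_available() else "32-true",
            logger=mlf_logger,
        )
        trainer.fit(model, datamodule=datamodule)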

You can find the full code that the excerpt above comes from here:
Lightning-MLflow code example