PyTorch Lightning CLI with Optuna Hyperparameter search - How to set PruningCallback?

Hello,

I want to use PyTorch Lightning (2.1.0) with its CLI functionality.
For a hyperparameter search, I manually set the CLI parameters in args: List[str] and pass them to the CLI. In args, I define the parameters that change between trials; fixed settings are provided by a config file.
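As an illustration of that setup, here is a minimal sketch that turns sampled values into an args list. The helper name build_cli_args, the config file name and the dotted keys such as model.learning_rate are hypothetical. It assumes that LightningCLI's parser (jsonargparse) applies arguments in order, so values listed after --config override the matching entries from the config file.

```python
from typing import Any, Dict, List


def build_cli_args(config_path: str, params: Dict[str, Any]) -> List[str]:
    """Build an args list for LightningCLI(args=..., run=False).

    The config file comes first; each sampled value listed afterwards
    overrides the corresponding entry loaded from the config file.
    """
    args = ["--config", config_path]
    for key, value in params.items():
        # Keys use the CLI's dotted names, e.g. "model.learning_rate" (hypothetical).
        args.append(f"--{key}={value}")
    return args


# Hypothetical values, e.g. as sampled by trial.suggest_float / trial.suggest_int.
example = build_cli_args("base_config.yaml", {"model.learning_rate": 0.001, "data.batch_size": 64})
print(example)
# ['--config', 'base_config.yaml', '--model.learning_rate=0.001', '--data.batch_size=64']
```

With run=False, no subcommand such as fit is expected in the list; the CLI only parses and instantiates, and training is started afterwards with cli.trainer.fit(...).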

For hyperparameter search I use Optuna. They provide a nice example of combining the two: Optuna Pytorch Lightning Example
The example uses the standard Python interface and constructs the Trainer instance directly.
On line 137 of that example, they pass a special callback to the trainer: the PyTorchLightningPruningCallback. This callback is similar to an EarlyStopping callback: after each epoch, it reports the intermediate metric to Optuna, so that the framework can decide whether to stop (prune) the trial early due to bad performance. The callback needs the instance of the corresponding optuna.trial.Trial, which is passed to its constructor (init_args).
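A simplified sketch of the per-epoch step such a pruning callback performs, assuming only the two optuna.trial.Trial methods report(value, step) and should_prune(). The names report_epoch and TrialLike are hypothetical; the real PyTorchLightningPruningCallback also reads the monitored value from the trainer's metrics and handles further details, among other things.

```python
from typing import Protocol


class TrialLike(Protocol):
    """The two optuna.trial.Trial methods a pruning callback relies on."""

    def report(self, value: float, step: int) -> None: ...

    def should_prune(self) -> bool: ...


def report_epoch(trial: TrialLike, value: float, epoch: int) -> bool:
    """Report one epoch's monitored metric; return True if the trial should stop."""
    # Optuna stores the intermediate value for this step (here: the epoch index).
    trial.report(value, step=epoch)
    # The study's pruner compares it with other trials and decides.
    return trial.should_prune()
```

Inside a Lightning callback, a True result would typically be turned into raise optuna.TrialPruned(), which Optuna's study.optimize() records as a pruned trial instead of a failure.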

However, with the CLI I don’t see how to provide the callback with the instance of the current trial.

As a workaround, I could set the trial as a member of the LightningModule after the CLI has instantiated it, and use that member in custom hooks to call the callback's hooks myself. But this is not the intended way (see Lightning Doc Callbacks).

So, how can I attach the PruningCallback to the trainer when using the CLI?

I need this:

trainer = pl.Trainer(callbacks=[PyTorchLightningPruningCallback(trial, monitor="val_acc")])

to be used with the Lightning CLI, or something like this:

# pseudo
trainer = cli.trainer.add_callback(PyTorchLightningPruningCallback(trial, monitor="val_acc"))

Thank you

I have found a solution that works for me, although it was somewhat hidden in the documentation.

In the Trainer - PyTorch Lightning 2.1.0 documentation there is a reference to LightningModule.configure_callbacks(). With this hook, I can instantiate the model, pass it my optuna.trial.Trial through a custom setter, and implement configure_callbacks() to create the desired callback.
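A minimal sketch of that approach, written as a mixin so it can be combined with an existing model as class MyModel(OptunaPruningMixin, pl.LightningModule). OptunaPruningMixin, set_trial, pruning_monitor, objective, config.yaml and --model.learning_rate are hypothetical names. It assumes the config file selects the model (and data) via class_path, and that the model logs "val_acc". Depending on the Optuna version, importing PyTorchLightningPruningCallback from optuna.integration may require the separate optuna-integration package.

```python
from typing import Any, List, Optional


class OptunaPruningMixin:
    """Hypothetical mixin that lets a LightningModule attach Optuna's pruning callback.

    List it before pl.LightningModule in the bases so that this
    configure_callbacks() comes first in the method resolution order.
    """

    # Must match a metric logged with self.log(), e.g. self.log("val_acc", acc).
    pruning_monitor: str = "val_acc"

    def set_trial(self, trial: Any) -> None:
        # Custom setter, called after LightningCLI(run=False) has built the model.
        self._optuna_trial = trial

    def configure_callbacks(self) -> List[Any]:
        # Lightning calls this hook when the model is attached (e.g. in fit())
        # and merges the returned callbacks into the Trainer's callback list.
        trial: Optional[Any] = getattr(self, "_optuna_trial", None)
        if trial is None:
            return []  # plain CLI run without Optuna: no extra callbacks
        # Imported lazily so the model still works where Optuna is not installed.
        from optuna.integration import PyTorchLightningPruningCallback

        return [PyTorchLightningPruningCallback(trial, monitor=self.pruning_monitor)]


def objective(trial: Any) -> float:
    """Hypothetical Optuna objective; config path and argument names are assumptions."""
    from lightning.pytorch.cli import LightningCLI

    lr = trial.suggest_float("learning_rate", 1e-5, 1e-1, log=True)
    args = ["--config", "config.yaml", f"--model.learning_rate={lr}"]
    cli = LightningCLI(args=args, run=False)  # parse + instantiate, no subcommand
    cli.model.set_trial(trial)  # must happen before fit() calls configure_callbacks()
    # If the trial is pruned, optuna.TrialPruned propagates out of fit().
    cli.trainer.fit(cli.model, datamodule=cli.datamodule)
    return cli.trainer.callback_metrics["val_acc"].item()
```

The objective would then be run as usual, e.g. optuna.create_study(direction="maximize").optimize(objective, n_trials=20). Keeping the trial on the model means the config file itself never needs to reference the callback.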
