Callback to Set global_step and current_epoch

I am trying to design a callback that resets the trainer's current_epoch and global_step to prespecified values, but my global_step update gets lost.

Let me first explain the context:

The use case is hyperparameter optimization (HPO) for a model initialized from a checkpoint taken midway through a finetuning run. Every trial must resume from the epoch at which the checkpoint was taken, and a correct global step is critical for the one-cycle LR scheduler to work correctly. Since we use multi-fidelity configurations, a trial may run at a different fidelity than the original checkpoint, so the global step must be recomputed from the steps-per-epoch of the fidelity chosen for that trial.
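For concreteness, here is the arithmetic I have in mind, as a minimal sketch; all names and values are placeholders for whatever the HPO framework actually supplies, not code from the real setup:

import math

# Placeholder fidelity parameters supplied by the HPO framework.
num_train_samples = 50_000   # training set size
batch_size = 128             # fidelity-dependent batch size
resume_epoch = 4             # 0-based index of the last completed epoch

# Steps per epoch under the trial's fidelity, not the checkpoint's.
steps_per_epoch = math.ceil(num_train_samples / batch_size)

# Epochs 0..resume_epoch are treated as already completed.
resume_step = (resume_epoch + 1) * steps_per_epoch   # 5 * 391 = 1955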

Each trial starts by resetting the initial model weights, then sets up a trainer, updates the current epoch and global step through the callback, and then runs the rest of the fit, roughly as sketched below.
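A sketch of one trial's driver code, assuming hypothetical helpers make_model and make_datamodule that stand in for the real model and data setup:

import pytorch_lightning as pl

def run_trial(resume_epoch, steps_per_epoch, max_epochs):
    model = make_model()            # hypothetical: restores the checkpoint weights
    datamodule = make_datamodule()  # hypothetical: builds the fidelity's dataloaders
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        callbacks=[
            ResetGlobalStepCallback(epoch=resume_epoch,
                                    steps_per_epoch=steps_per_epoch),
        ],
    )
    trainer.fit(model, datamodule=datamodule)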

Specifically, my callback is as follows:

import pytorch_lightning as pl


class ResetGlobalStepCallback(pl.Callback):
    def __init__(self, epoch=None, steps_per_epoch=0):
        super().__init__()
        if epoch is None:
            # No checkpoint: start from scratch.
            self._epoch = 0
            self._step = 0
        else:
            # Resume after `epoch`: treat epoch + 1 full epochs' worth of
            # steps at this fidelity as already completed.
            self._epoch = epoch
            self._step = (epoch + 1) * steps_per_epoch

    def on_train_start(self, trainer, pl_module):
        # This assignment is the one whose effect gets lost (see below).
        trainer.fit_loop.epoch_loop._global_step = self._step
        trainer.fit_loop.epoch_progress.current.completed = self._epoch

        print("Training starts with epoch set to", trainer.current_epoch)
        print("Training starts with global step set to", trainer.global_step)

The update to the current epoch is reflected in trainer.current_epoch, but trainer.global_step is always zero, even after trainer.fit_loop.epoch_loop._global_step has been assigned the intended value.
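From skimming the Lightning 2.0 source, trainer.global_step appears to be a read-only property that (under automatic optimization) resolves to the optimizer-step progress tracker, so assigning a _global_step attribute presumably just creates an unused attribute on the loop. If that reading is correct, writing the tracker directly might work; an untested sketch against the 2.0.x internals:

import pytorch_lightning as pl


class ResetProgressCallback(pl.Callback):
    def __init__(self, epoch=None, steps_per_epoch=0):
        super().__init__()
        self._epoch = 0 if epoch is None else epoch
        self._step = 0 if epoch is None else (epoch + 1) * steps_per_epoch

    def on_train_start(self, trainer, pl_module):
        # In 2.0.x, trainer.global_step seems to resolve (for automatic optimization) to
        # fit_loop.epoch_loop.automatic_optimization.optim_progress.optimizer.step.total.completed,
        # so that tracker is what has to be written, not a `_global_step` attribute.
        optim_progress = trainer.fit_loop.epoch_loop.automatic_optimization.optim_progress
        optim_progress.optimizer.step.total.completed = self._step
        trainer.fit_loop.epoch_progress.current.completed = self._epoch

Even if that works, it pokes at private loop internals, so I would not call it clean.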

I am using Lightning 2.0.2 and PyTorch 2.0.1.

Can someone suggest a cleaner way of doing this?