I am trying to design a callback that resets the trainer's `current_epoch` and `global_step` to pre-specified values, but my update to `global_step` does not take effect.
Let me first explain the context:
The use case is hyperparameter optimization (HPO) for a model initialized from a checkpoint taken midway through a finetuning run. Every trial must resume from the epoch at which the checkpoint was taken, and a correct value for the global step is critical for the one-cycle LR scheduler to work properly. Since we are using multi-fidelity configurations, a trial may run at a different fidelity than the original checkpoint, so the global step must be recomputed from the steps-per-epoch of the chosen fidelity.
Each trial starts by restoring the initial model weights, then sets up a trainer, updates the current epoch and global step through the callback, and finally executes the rest of the fit.
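For concreteness, the target global step for a trial is just the number of completed optimizer steps implied by the checkpoint epoch and the chosen fidelity's steps-per-epoch (the numbers below are hypothetical):

```python
# Hypothetical numbers: checkpoint taken at the end of epoch 4, and the
# chosen fidelity runs 50 optimizer steps per epoch (this depends on the
# fidelity's dataset size and batch size).
checkpoint_epoch = 4
steps_per_epoch = 50

# Epochs are zero-indexed, so epochs 0..4 are complete: 5 full epochs of steps.
target_global_step = (checkpoint_epoch + 1) * steps_per_epoch
print(target_global_step)  # 250
```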
Specifically, my callback is as follows:
```python
class ResetGlobalStepCallback(pl.Callback):
    def __init__(self, epoch=None, steps_per_epoch=0):
        super().__init__()
        if epoch is None:
            self._epoch = 0
            self._step = 0
        else:
            # Resuming from the end of `epoch` (zero-indexed), so the number
            # of completed optimizer steps is (epoch + 1) * steps_per_epoch.
            self._epoch = epoch
            self._step = (epoch + 1) * steps_per_epoch

    def on_train_start(self, trainer, pl_module):
        trainer.fit_loop.epoch_loop._global_step = self._step
        trainer.fit_loop.epoch_progress.current.completed = self._epoch
        print("Training starts with epoch set to", trainer.current_epoch)
        print("Training starts with global step set to", trainer.global_step)
```
The update to the epoch is reflected in `trainer.current_epoch`, but `trainer.global_step` is always zero, even though `trainer.fit_loop.epoch_loop._global_step` holds the intended value.
I am using Lightning 2.0.2 and PyTorch 2.0.1.
Can someone suggest a cleaner way of doing this?