Is it possible to load a checkpoint halfway through the fit using a callback?

For example, I want to train for a total of 100 epochs. At the 51st epoch, instead of continuing from the model as it stood at the end of the 50th epoch, I want to continue training from the best checkpoint saved within the first 50 epochs. Is it possible to do this with a callback, without manually calling Trainer.fit for the first 50 epochs and then calling Trainer.fit(ckpt_path=best_first_50.ckpt) again with the best checkpoint?

I think running the trainer for 50 epochs, then loading the best checkpoint, and then running a (new) trainer for the remaining epochs is the simplest solution; see the sketch below.
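A minimal sketch of that two-stage approach, assuming a recent PyTorch Lightning version where Trainer.fit accepts ckpt_path; MyModel and MyDataModule are placeholders for your own classes:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

model = MyModel()          # placeholder LightningModule
dm = MyDataModule()        # placeholder LightningDataModule

# stage 1: train the first 50 epochs, tracking the best checkpoint
checkpoint_cb = ModelCheckpoint(monitor="val_loss", mode="min")
trainer = pl.Trainer(max_epochs=50, callbacks=[checkpoint_cb])
trainer.fit(model, datamodule=dm)

# stage 2: resume from the best of the first 50 epochs and train up to 100.
# Note: ckpt_path restores the full training state, including the epoch
# counter, so training continues from the epoch at which the best
# checkpoint was saved.
trainer = pl.Trainer(max_epochs=100, callbacks=[ModelCheckpoint(monitor="val_loss", mode="min")])
trainer.fit(model, datamodule=dm, ckpt_path=checkpoint_cb.best_model_path)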

Alternatively, you could try something like this. It is not too hacky and probably acceptable:

# inside your LightningModule
import torch

def on_train_epoch_start(self):
    # current_epoch is 0-indexed, so epoch 50 is the 51st epoch
    if self.current_epoch == 50:
        # path to the best checkpoint tracked by the ModelCheckpoint callback
        filename = self.trainer.checkpoint_callback.best_model_path
        checkpoint = torch.load(filename)
        # restore the model weights only; optimizer state is left as-is
        self.load_state_dict(checkpoint["state_dict"])

I haven’t tested it, so there may be a typo, but this should work with most strategies. Note that it restores only the model weights; the optimizer and LR scheduler keep their current state.
