What is currently the best/correct way to train in stages? I mean, for example, that at first I train the model with one loss, and after some time switch to a different one, or change the augmentations. I can think of a few options:
1. Write a callback that checks the training status in on_epoch_start (or some other hook) and changes the parameters of the LightningModule. But I'm not sure this is safe; for example, would it be okay to change the dataloaders from such a callback?
2. Wait until training is finished, then load the checkpoint and continue. But I'm not sure how to change the optimizers/dataloaders and other things in this case.
3. Save the model directly (a PyTorch .pth file), then create a new instance of the LightningModule with the new parameters and load the weights from the file. Then train.
Which of these approaches is better, or is there perhaps a different, better way to do this?
I think the best option as of now is option 2. Model.load_from_checkpoint() accepts additional kwargs that override the ones loaded from the checkpoint, so you could do something like:
```python
model = Model(stage=1)
trainer.fit(model)  # stage 1

# either flip the stage on the same instance ...
model.stage = 2
# ... or load the checkpoint and override the stage
model = Model.load_from_checkpoint("stage1.ckpt", stage=2)

trainer.fit(model)  # stage 2
...
```
```python
class Model(LightningModule):
    # anything you would like to change can branch on self.stage
    def configure_optimizers(self):
        if self.stage == 1:
            return ...
        elif self.stage == 2:
            return ...
```
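The kwargs passed to load_from_checkpoint() override the hyperparameters stored in the checkpoint, so it helps to register stage with self.save_hyperparameters(). A minimal sketch of how the module could look; the layer, optimizer choices and learning rates here are placeholders, not part of the original answer:

```python
import torch
from pytorch_lightning import LightningModule

class Model(LightningModule):
    def __init__(self, stage: int = 1):
        super().__init__()
        # stores `stage` in the checkpoint so load_from_checkpoint(..., stage=2)
        # can override it when resuming for the next stage
        self.save_hyperparameters()
        self.stage = stage
        self.layer = torch.nn.Linear(32, 1)  # placeholder network

    def configure_optimizers(self):
        # placeholder optimizers; use whatever each stage actually needs
        if self.stage == 1:
            return torch.optim.Adam(self.parameters(), lr=1e-3)
        return torch.optim.SGD(self.parameters(), lr=1e-4)
```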
Hopefully this helps! If this doesn't seem like the best way, perhaps you can share your exact use case and we can figure it out.
For the point about changing dataloaders: if you use a DataModule, you could add additional stages aside from 'fit' and 'test' and then call dm.setup(stage='fit_stage2') from a callback when you reach the appropriate point in your training procedure, as in the sketch below.
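Here is a rough sketch of that idea. The 'fit_stage2' stage string comes from the suggestion above; the dataset-building helpers, the SwitchToStage2 callback, the switch epoch, and the reload_dataloaders_every_n_epochs setting are assumptions for illustration:

```python
from pytorch_lightning import Callback, LightningDataModule, Trainer
from torch.utils.data import DataLoader

class StagedDataModule(LightningDataModule):
    def setup(self, stage=None):
        # 'fit_stage2' is a custom stage on top of the built-in 'fit'/'test'
        if stage == "fit_stage2":
            self.train_set = build_stage2_dataset()  # hypothetical helper
        else:
            self.train_set = build_stage1_dataset()  # hypothetical helper

    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=32)

class SwitchToStage2(Callback):
    def __init__(self, switch_epoch: int = 10):
        self.switch_epoch = switch_epoch

    def on_train_epoch_start(self, trainer, pl_module):
        # re-run setup with the custom stage once the switch epoch is reached
        if trainer.current_epoch == self.switch_epoch:
            trainer.datamodule.setup(stage="fit_stage2")

# reload_dataloaders_every_n_epochs=1 makes the Trainer call train_dataloader()
# again each epoch, so the swapped-in dataset is actually picked up
trainer = Trainer(max_epochs=20,
                  reload_dataloaders_every_n_epochs=1,
                  callbacks=[SwitchToStage2(switch_epoch=10)])
```

You would then run trainer.fit(model, datamodule=StagedDataModule()) as usual.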