Validate when the training_loss reaches a new low, and save the model according to the validation_loss


The current ModelCheckpoint can trigger a validation epoch when

  • a certain number of training epochs elapses, or

  • a certain number of training steps elapses

I’m wondering whether there is a way to trigger a validation epoch when a certain event happens (e.g. when the training_loss reaches a new low), and then decide whether or not to save the model as a checkpoint according to the validation_loss?

More precisely, we want to monitor the training process according to the following rules:

  • Compute and log the training_loss at every training step

  • When the training_loss reaches a new low, start a validation epoch

  • Compute and log the validation_loss

    • If the validation_loss reaches a new low (or ranks in the top 3), or this is the first validation epoch, save the model as a checkpoint

    • If the validation_loss is not better than the previously logged validation_loss, do nothing

  • Then get back to the training process.
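
To make the rules above concrete, here is a minimal, framework-agnostic sketch of the decision logic in plain Python. It is not a Lightning API: `BestLossTracker`, `run_training`, `validate_fn` and `save_fn` are hypothetical names made up for illustration, and I am assuming "new low" means strictly lower than every earlier value. How to hook this into the Trainer is exactly the open question.

```python
import math


class BestLossTracker:
    """Hypothetical helper: remembers the lowest training loss and the
    top_k lowest validation losses seen so far."""

    def __init__(self, top_k=3):
        self.top_k = top_k
        self.best_train_loss = math.inf
        self.best_val_losses = []  # kept sorted ascending, at most top_k entries

    def should_validate(self, train_loss):
        """Return True when train_loss is a strict new low."""
        if train_loss < self.best_train_loss:
            self.best_train_loss = train_loss
            return True
        return False

    def should_save(self, val_loss):
        """Return True when val_loss ranks among the top_k lowest so far.
        The first validation always qualifies because the list is empty."""
        has_room = len(self.best_val_losses) < self.top_k
        if has_room or val_loss < self.best_val_losses[-1]:
            self.best_val_losses.append(val_loss)
            self.best_val_losses.sort()
            del self.best_val_losses[self.top_k:]  # drop anything beyond top_k
            return True
        return False


def run_training(train_losses, validate_fn, save_fn, top_k=3):
    """Walk over per-step training losses, validate on each new low, and
    save when the validation loss is good enough. Returns the saved steps."""
    tracker = BestLossTracker(top_k=top_k)
    saved_steps = []
    for step, train_loss in enumerate(train_losses):
        if not tracker.should_validate(train_loss):
            continue  # no new training low: keep training
        val_loss = validate_fn(step)  # stands in for a full validation epoch
        if tracker.should_save(val_loss):
            save_fn(step, val_loss)  # stands in for writing a checkpoint
            saved_steps.append(step)
    return saved_steps
```

With `top_k=1` only a new validation low triggers a save; with `top_k=3` the first three validations are always saved, and later ones are saved only if they beat the worst of the current top 3.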