How to create a callback that detects whether gradients explode, and if so, rolls back to the last checkpoint and resets the optimizer?

I’m getting terribly unstable learning in my SSL training.

I want to create a callback that detects whether gradients are exploding, and if so, rolls back to the last checkpoint and resets the optimizer. How do I do this?

Hi @RylanSchaeffer

You could do this by accessing trainer.checkpoint_callback.last_model_path and then loading the weights back in. This can work, but it might be very inefficient, as you could end up restarting many times and losing valuable training progress.
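
Here is a rough, untested sketch of what such a callback could look like. The class name, the threshold value, and the choice to "reset" the optimizer by clearing its state are my assumptions; it also assumes a ModelCheckpoint is attached so last_model_path is populated.

import torch
import pytorch_lightning as pl


class GradientExplosionRollback(pl.Callback):
    # Sketch: if the global gradient norm exceeds a threshold, reload the
    # weights from the last saved checkpoint, clear the optimizer state, and
    # zero the gradients so the upcoming optimizer step is effectively a no-op.
    def __init__(self, grad_norm_threshold=1000.0):
        self.grad_norm_threshold = grad_norm_threshold  # arbitrary example value

    def on_after_backward(self, trainer, pl_module):
        grads = [p.grad.detach() for p in pl_module.parameters() if p.grad is not None]
        if not grads:
            return
        total_norm = torch.norm(torch.stack([g.norm(2) for g in grads]), 2)
        if total_norm <= self.grad_norm_threshold:
            return

        ckpt_path = trainer.checkpoint_callback.last_model_path
        if ckpt_path:
            # Roll the model back to the last checkpoint written by ModelCheckpoint.
            checkpoint = torch.load(ckpt_path, map_location=pl_module.device)
            pl_module.load_state_dict(checkpoint["state_dict"])
        # "Reset" the optimizer by clearing its state (e.g. Adam's running moments).
        for optimizer in trainer.optimizers:
            optimizer.state.clear()
        # Drop the exploding gradients so they are not applied to the restored weights.
        pl_module.zero_grad(set_to_none=True)

You would attach it with Trainer(callbacks=[GradientExplosionRollback()]).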

Another option you could try is adding gradient clipping.
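
In Lightning that is just a Trainer flag, for example (the clip value is arbitrary):

from pytorch_lightning import Trainer

# Clip the global gradient norm before every optimizer step.
trainer = Trainer(gradient_clip_val=0.5, gradient_clip_algorithm="norm")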

It could also be that you have a few bad training examples. You could return the index/filename as part of the batch and, whenever your loss explodes, log that batch. Then inspect those training examples in the dataset and/or exclude them from future runs. If manually excluding them is not practical, you could also check the magnitude of the loss and skip the optimization step altogether by returning None from training_step (a sketch of the index-logging idea follows the snippet below):

def training_step(self, batch, batch_idx):
    ...
    if loss > threshold:
        return None  # Lightning skips to the next batch when None is returned
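
For the index-logging part, here is a minimal sketch of a wrapper dataset (the IndexedDataset name is mine, and it assumes a map-style dataset):

from torch.utils.data import Dataset


class IndexedDataset(Dataset):
    # Wraps an existing dataset so every sample also carries its index,
    # which lets you trace a loss spike back to the offending examples.
    def __init__(self, base_dataset):
        self.base = base_dataset

    def __len__(self):
        return len(self.base)

    def __getitem__(self, idx):
        return self.base[idx], idx

In training_step you would then unpack the extra index from the batch and log idx.tolist() whenever the loss crosses your threshold, before returning None.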

But if batches get skipped too often, that can itself slow training down or hurt generalization.

I’ve already tried gradient clipping and gradient skipping. I noticed that increasing Adam’s epsilon from 1e-8 to 1e-3 delays the blowup, but it still occurs. A friend said this is a known issue with Adam and suggested switching to Adadelta.
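
For reference, the two variants look roughly like this in configure_optimizers (the learning rates are just placeholders):

import torch


def configure_optimizers(self):
    # What I tried: Adam with a much larger eps (delayed the blowup, but it still happened).
    # return torch.optim.Adam(self.parameters(), lr=1e-3, eps=1e-3)

    # What my friend suggested: switch to Adadelta.
    return torch.optim.Adadelta(self.parameters(), lr=1.0)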

If I did want to do this, what would the callback look like? Could you give me a sketch? (I’ve never written a custom callback.)

Gradient clipping doesn’t always help, and I personally find that decreasing the learning rate as training goes on is a better way to prevent exploding gradients.
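
For example, a decaying schedule can be returned from configure_optimizers (the gamma value is arbitrary):

import torch


def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    # Multiply the learning rate by gamma after every epoch.
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.95)
    return {"optimizer": optimizer, "lr_scheduler": scheduler}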

I’m already using LR scheduling. It isn’t enough :frowning: