Implementing Gradient Skipping

Hi, I would like to implement gradient skipping in PyTorch Lightning (PL), i.e. skip training updates whose gradient norm is above a certain threshold.

In other words,

  1. Calculate the gradient norm of the model parameters
  2. If the gradient norm > thresh, skip optimizer.step(); otherwise call it as usual
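The two steps above can be sketched without any framework. This is a minimal illustration, assuming gradients are given as plain lists of floats (one list per parameter) and that `thresh` is a user-chosen constant; `total_grad_norm` and `should_step` are hypothetical helper names. The total norm is the L2 norm of all gradient entries viewed as one vector, which is also the value `torch.nn.utils.clip_grad_norm_` returns in PyTorch.

```python
import math
from typing import Sequence


def total_grad_norm(grads: Sequence[Sequence[float]]) -> float:
    """L2 norm of all gradient entries, treated as one flat vector."""
    return math.sqrt(sum(g * g for grad in grads for g in grad))


def should_step(grads: Sequence[Sequence[float]], thresh: float) -> bool:
    """True if the update should be applied, False if it should be skipped."""
    return total_grad_norm(grads) <= thresh


# Two hypothetical parameters: a 2-element weight and a scalar bias.
grads = [[3.0, 0.0], [4.0]]
print(total_grad_norm(grads))            # 5.0
print(should_step(grads, thresh=10.0))   # True
print(should_step(grads, thresh=1.0))    # False
```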

Any advice on the recommended way to implement this in a LightningModule?

You can override optimizer_step:

def optimizer_step(self, *args, **kwargs):
    grad_norm = ...  # calculate grad_norm here
    # only update when the norm is within the threshold; skip otherwise
    if grad_norm <= thresh:
        super().optimizer_step(*args, **kwargs)

Thanks for your reply!

However, if I were to extract the gradient information within the optimizer_step method, using something like this:

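import torch

def optimizer_step(self, *args, **kwargs):
    # All p.grad are tensors of zeros
    parameters = [p for p in self.parameters() if p.grad is not None]
    # total L2 norm of all gradients, viewed as a single vector
    grad_norm = torch.norm(torch.stack([p.grad.detach().norm(2) for p in parameters]), 2)

    # skip the update when the norm exceeds the threshold
    if grad_norm <= thresh:
        super().optimizer_step(*args, **kwargs)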

All the parameters appear to have zero accumulated gradients. What could be happening here? Has loss.backward() not been called yet at this point?

OK, yeah… the backward pass happens inside the closure, so the gradients are not populated yet when optimizer_step starts. I need to find a better alternative.
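That diagnosis can be illustrated without Lightning. In the toy simulation below (plain Python; `ToyParam`, `make_closure` and `gated_optimizer_step` are hypothetical names, and the "backward" is the analytic gradient of a squared loss), gradients read before the closure runs are still zero, just as observed above. Calling the closure first, then measuring the norm, then applying or skipping the update gives the intended ordering. How to wire this into a particular Lightning version's optimizer_step is not shown here; its signature has changed between releases, so check that version's documentation.

```python
from typing import Callable, List


class ToyParam:
    """Hypothetical stand-in for a model parameter with a .grad attribute."""

    def __init__(self, value: float) -> None:
        self.value = value
        self.grad = 0.0  # stays 0.0 until a backward pass writes to it


def make_closure(params: List[ToyParam], target: float) -> Callable[[], float]:
    """Hypothetical closure: forward computes sum((w - target)^2), backward sets each grad."""

    def closure() -> float:
        for p in params:
            p.grad = 2.0 * (p.value - target)  # d/dw of (w - target)^2
        return sum((p.value - target) ** 2 for p in params)

    return closure


def gated_optimizer_step(params: List[ToyParam], closure: Callable[[], float],
                         lr: float, thresh: float) -> bool:
    """Run the closure first, then apply an SGD update only if the grad norm is <= thresh.

    Returns True if the update was applied and False if it was skipped.
    """
    closure()  # forward + backward: only now are the gradients populated
    norm = sum(p.grad ** 2 for p in params) ** 0.5
    if norm > thresh:
        return False  # skip: parameters are left unchanged
    for p in params:
        p.value -= lr * p.grad
    return True


params = [ToyParam(1.0)]
closure = make_closure(params, target=0.0)
print(params[0].grad)  # 0.0, because the closure has not run yet
print(gated_optimizer_step(params, closure, lr=0.5, thresh=5.0))  # True
print(params[0].value)  # 0.0
```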

Hi there! Same problem here: I would like to skip the gradient and optimization computation. For this I use manual_optimizer and manual_backward (see this). This works, but the loss is then shown as “nan” in the progress bar. Anyway, here is an example:

def training_step(self, batch, batch_nb):
    loss_dict = self.compute_loss(...)

    # skip backward and the optimizer step entirely for this batch
    if loss_dict is None:
        return

    loss = sum(loss_dict.values())

    opt = self.optimizers()
    self.manual_backward(loss, opt)
    self.manual_optimizer_step(opt)

    logs = {'loss': loss}
    logs.update({'train_' + k: v.item() for k, v in loss_dict.items()})

    return logs
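With manual optimization, as in the example above, the gradients exist right after manual_backward, so that is where a norm check fits, before the optimizer step. The framework-free sketch below shows that control flow for a hypothetical one-parameter model y ~ w * x with mean squared error; `toy_manual_step` is a plain function, not Lightning's hook, and the analytic gradient stands in for manual_backward. A skipped update still returns the loss so it can be logged.

```python
from typing import List, Optional, Tuple

Batch = List[Tuple[float, float]]  # hypothetical (x, y) pairs


def toy_manual_step(w: float, batch: Batch, lr: float,
                    thresh: float) -> Tuple[float, Optional[float], bool]:
    """One manual-optimization step for y ~ w * x with mean squared error.

    Returns (new_w, loss, stepped). loss is None when the batch is empty,
    mirroring the early return when loss_dict is None.
    """
    if not batch:
        return w, None, False  # nothing to compute: skip backward and step
    # forward pass: mean squared error
    loss = sum((w * x - y) ** 2 for x, y in batch) / len(batch)
    # "backward" pass: analytic gradient d(loss)/dw
    grad = sum(2.0 * x * (w * x - y) for x, y in batch) / len(batch)
    # the gradient only exists after backward, so the check goes here
    if abs(grad) > thresh:  # for a single scalar parameter the L2 norm is |grad|
        return w, loss, False  # skip the update but still report the loss
    return w - lr * grad, loss, True


print(toy_manual_step(0.0, [(1.0, 2.0)], lr=0.25, thresh=10.0))     # (1.0, 4.0, True)
print(toy_manual_step(0.0, [(10.0, 100.0)], lr=0.25, thresh=10.0))  # (0.0, 10000.0, False)
print(toy_manual_step(0.0, [], lr=0.25, thresh=10.0))               # (0.0, None, False)
```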

I’ve also tried this and get the same result: the loss is logged as “nan”. I haven’t quite figured out why…

Here is a simple example in Colab.

Hello, I read the solution in the GitHub issues: for the loss showing as nan in the tqdm progress bar, you can just do:

self.trainer.train_loop.running_loss.append(loss)

Hello there, it seems that manual optimization does not work with native AMP in the newest version. Do we need an extra step, such as overriding the backward method with the code in "Introducing native PyTorch automatic mixed precision for faster training on NVIDIA GPUs" (PyTorch)?