Why does training fail with "require grad and does not have a grad_fn"?

I am running a Temporal Fusion Transformer model using a custom data module that provides data as torch.tensor objects. The loss I am using is a QuantileLoss as used in pytorch_forecasting (a highly customised metric).

Training begins and runs until the last step, where it throws the following RuntimeError:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
Epoch 0: 100%|█████████▉| 540/541 [00:31<00:00, 17.00it/s, v_num=561]    
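For context, this error is raised whenever `backward()` is called on a tensor that is detached from the autograd graph. A minimal standalone repro (not the model code, just an illustration of the failure mode):

```python
import torch

# A tensor created from scratch has no grad_fn and does not require grad,
# so calling backward() on it raises exactly this RuntimeError.
detached_loss = torch.tensor(1e9)
print(detached_loss.requires_grad, detached_loss.grad_fn)  # False None

try:
    detached_loss.backward()
except RuntimeError as e:
    print(e)  # element 0 of tensors does not require grad and does not have a grad_fn
```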

What I have identified already:

  • it is not the optimizer (switching or disabling it makes no difference)
  • manual_optimization mode works (but is not desirable)
  • in the LAST step, print(f"{self.loss.requires_grad}") gives `False`, whereas it was always `True` in previous steps
  • all parameters within the model are checked using:

for module in self.modules():
    params = list(module.parameters())
    if np.sum([not p.requires_grad for p in params]) > 0:
        print(f"{module} failed grad check")
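The same check can be written more directly with `named_parameters()`, which also tells you which parameters are frozen; a minimal sketch on a hypothetical toy model (the check itself is model-agnostic):

```python
import torch.nn as nn

# Hypothetical toy model standing in for the TFT; the check is generic.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

frozen = [name for name, p in model.named_parameters() if not p.requires_grad]
if frozen:
    print(f"parameters without grad: {frozen}")
else:
    print("all parameters require grad")
```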

The training step is very basic:

    def training_step(self, batch, batch_idx):
        """Train step on batch."""
        y_hat = self(batch)
        loss = self.loss(y_hat, batch)
        return loss

y_hat has requires_grad set correctly:

In [3]: for key in y_hat.keys():
   ...:     print(f"{key}: {y_hat[key].requires_grad}")
predicted_quantiles: True
static_weights: True
historical_selection_weights: True
future_selection_weights: True
attention_scores: True
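One way to localise where the graph breaks is to inspect `grad_fn` on the loss right before returning it: if `loss.grad_fn` is `None` while `y_hat` still has one, the detachment happens inside the loss computation. A sketch with stand-in tensors:

```python
import torch

y_hat = torch.randn(8, 3, requires_grad=True) * 2  # stand-in for model output
loss = y_hat.abs().mean()

print(loss.grad_fn)           # not None -> still connected to the graph
print(loss.detach().grad_fn)  # None -> graph broken at this point
```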

So how can it be that my loss suddenly carries no gradient information? I am stuck here… Does anyone have an idea which part of the pipeline can interfere with my loss in that way? Where should I look, and how do I fix it?

Best, Falco

Sometimes it just needs a push… I solved it myself.
There was an update in the loss function that reassigned the loss to a freshly initialized tensor, which detaches it from the autograd graph and fails as expected. The line now reads as follows, with the erroneous line commented out:

            if not torch.isfinite(losses):
                losses = losses.fill_(1e9)
                # losses = torch.tensor(1e9, device=losses.device)
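To illustrate the difference (a standalone sketch, not the pytorch_forecasting code): reassigning to a freshly constructed tensor detaches the loss from the graph, while an in-place `fill_` keeps the tensor inside the graph, so `backward()` still runs (upstream gradients are simply zero, since the filled value does not depend on the inputs).

```python
import torch

w = torch.ones(3, requires_grad=True)
losses = (w * 2).sum()  # non-leaf tensor, carries a grad_fn

# Reassigning to a fresh tensor detaches from the graph:
bad = torch.tensor(1e9, device=losses.device)
print(bad.requires_grad, bad.grad_fn)  # False None -> backward() would fail

# In-place fill_ keeps the tensor in the graph (its grad_fn is updated):
good = losses.fill_(1e9)
print(good.requires_grad, good.grad_fn is not None)  # True True
good.backward()  # runs fine; w.grad is all zeros
```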