Weird training logs with PyTorch Lightning

I am using the following PyTorch Lightning code with a WandbLogger (the code lives inside a LightningModule):

def training_step(self, batch, batch_idx):
    """Training step"""
    loss, acc, bleu = self._step(batch)
    self.log_dict(
        {"train/loss": loss, "train/accuracy": acc, "train/bleu_score": bleu},
        # also log the epoch-level mean on top of the default per-step values
        on_epoch=True,
        # tensors are sequence-first, so dim 1 is the batch dimension
        batch_size=batch[0].shape[1],
    )
    return loss

def validation_step(self, batch, batch_idx):
    """Validation step"""
    loss, acc, bleu = self._step(batch)
    self.log_dict(
        {"val/loss": loss, "val/accuracy": acc, "val/bleu_score": bleu},
        on_epoch=True,
        on_step=False,
        batch_size=batch[0].shape[1],
    )
    return loss

def test_step(self, batch, batch_idx):
    """Test step"""
    loss, acc, bleu = self._step(batch)
    self.log_dict(
        {"test/loss": loss, "test/accuracy": acc, "test/bleu_score": bleu},
        on_epoch=True,
        on_step=False,
        batch_size=batch[0].shape[1],
    )
    return loss

def _step(self, batch: torch.Tensor):
    # source and target are sequence-first: (seq_len, batch)
    source, target = batch
    # teacher forcing: feed the target shifted right (drop the last token)
    logits = self(source, target[:-1, :])  # (tgt_len - 1, batch, vocab)
    with torch.no_grad():
        bleu = self._batch_bleu(logits, target)
    # flatten sequence and batch dims: logits -> (N, vocab), target -> (N,)
    logits = logits.reshape(-1, logits.shape[2])
    target = target[1:].reshape(-1)
    loss = F.cross_entropy(logits, target, ignore_index=self.source_pad_idx)
    with torch.no_grad():
        acc = accuracy(
            logits,
            target,
            task="multiclass",
            num_classes=self.hparams.target_vocab_size,
            ignore_index=self.source_pad_idx,
            top_k=1,
        )
    # report accuracy and BLEU as percentages
    return loss, acc.item() * 100, bleu.item() * 100

The weird thing is that I get significantly worse values for all three metrics during training compared to validation and testing. When I tried running a validation epoch with the training loader, I got what one would expect: slightly better results on the training data. This tells me the problem is not in computing the metrics but in the way training_step is logging them. I looked at the documentation of Trainer, WandbLogger, and LightningModule, but found nothing. I also tried logging with TensorBoard, with no change in the result.
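
For reference, a validation pass over the training data can be run roughly like this; model and train_loader are placeholder names, not the actual objects from my code:

import lightning.pytorch as pl

# run validation_step over the training data with frozen weights;
# the val/* metrics are then computed exactly as during validation
trainer = pl.Trainer()
trainer.validate(model, dataloaders=train_loader)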

What is Lightning doing differently in training vs. evaluation when it comes to logging?

I am using lightning 2.0.0 and pytorch 2.0.

Loss graph: (image attachment not shown here)

Hey,
Lightning doesn’t do anything special except average your values when you specify on_epoch=True. Feel free to just log a constant value, or some other known curve, to verify it yourself!
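
For example, a minimal sketch of such a check, dropped into the same LightningModule (the constant 1.0 is arbitrary):

def training_step(self, batch, batch_idx):
    # log a known constant; the epoch-level aggregate should come out
    # as exactly 1.0 if the logging/averaging itself is behaving
    self.log("train/sanity_constant", 1.0, on_step=True, on_epoch=True)
    loss, acc, bleu = self._step(batch)
    self.log_dict(
        {"train/loss": loss, "train/accuracy": acc, "train/bleu_score": bleu},
        on_epoch=True,
        batch_size=batch[0].shape[1],
    )
    return loss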

I can see that you have the batch size in dim 1, but the cross-entropy loss expects the batch dimension to be dim 0. Could this be the cause?


I did try logging a constant and some simple functions and got the expected behavior, so I guess Lightning’s logging is not the culprit after all.

I “solved” the problem by removing the epoch-level logging from training_step altogether and replacing it with the following code:

def on_train_epoch_end(self):
    # switch to eval mode so dropout / batch norm behave as in validation
    self.eval()
    loader = self.trainer.train_dataloader
    loss, acc, bleu = 0, 0, 0
    for batch in tqdm(
        loader, desc="Computing metrics", total=len(loader), ncols=100
    ):
        batch = [b.to(self.device) for b in batch]
        with torch.no_grad():
            loss_, acc_, bleu_ = self._step(batch)
        loss += loss_
        acc += acc_
        bleu += bleu_
    # average over the number of batches
    loss /= len(loader)
    acc /= len(loader)
    bleu /= len(loader)
    self.log_dict(
        {
            "train_epoch/loss": loss,
            "train_epoch/accuracy": acc,
            "train_epoch/bleu_score": bleu,
        },
    )
    # restore training mode for the next epoch
    self.train()

which did give the intended result.
This, combined with the fact that my training metrics were previously comparable to the validation metrics of the preceding epoch, leads me to conjecture that the cause of the problem may be that the model improves so much within a single epoch that the first few batches drag the average down.
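
As a toy illustration of what I mean (the numbers are completely made up):

# per-batch accuracy improving steadily within one training epoch
per_batch_acc = [20, 35, 50, 60, 68, 72]

epoch_mean = sum(per_batch_acc) / len(per_batch_acc)  # roughly what on_epoch=True reports
end_of_epoch = per_batch_acc[-1]                      # what a separate post-epoch pass sees

print(epoch_mean)    # ~50.8 -> looks much worse than validation
print(end_of_epoch)  # 72    -> comparable to the validation metric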

This is also supported by the fact that the only thing that changed compared to my previous attempt is that I am no longer updating the model weights while computing the metrics. I am also no longer passing the batch_size parameter to log_dict, but I can’t see how that could make the results look worse than they actually are.

As for the shape of my batches, I don’t think it is the cause of the problem, because I flatten everything (in the same order) before computing the loss (and besides, accuracy and BLEU score are increasing).
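
As a quick sanity check of the flattening argument, here is a toy example with invented shapes, independent of my actual model:

import torch
import torch.nn.functional as F

# toy sequence-first tensors: (seq_len, batch, vocab) and (seq_len, batch)
seq_len, batch_size, vocab = 7, 4, 10
logits = torch.randn(seq_len, batch_size, vocab)
target = torch.randint(0, vocab, (seq_len, batch_size))

# flattening both in the same (sequence-first) order keeps every logit row
# aligned with its target token, so the loss is computed correctly
loss_flat = F.cross_entropy(logits.reshape(-1, vocab), target.reshape(-1))

# equivalent batch-first form: cross_entropy accepts (N, C, d1) with (N, d1)
loss_bf = F.cross_entropy(logits.permute(1, 2, 0), target.t())

torch.testing.assert_close(loss_flat, loss_bf)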

Thank you so much for your help 🙏.
