How to correctly log metrics and losses when the model returns a dictionary as output

I’m finding my way around PyTorch Lightning and getting different results compared to vanilla PyTorch. I’m wondering whether this is because I’m logging the loss incorrectly. I have multiple outputs and I’m summing their losses — could this be the error? Here is my code. I’d also appreciate feedback on whether the metric calculations are correct!

class MultiOutputModule(pl.LightningModule):
    """LightningModule for a model whose forward pass returns a dict of outputs.

    Every entry of the model's output dict is compared (MSE) with the label
    stored under the same key in the label dict; the per-output losses are
    summed, unweighted, into a single training objective.

    Validation metrics are kept per output in a ``MetricCollection`` (built from
    ``compile_params.metrics``). They are updated on every validation step and
    computed, logged and reset once at the end of each validation epoch.
    """

    def __init__(self, model: torch.nn.Module, compile_params: CompileParams):
        # LightningModule is an nn.Module: its __init__ must run before any
        # submodule (self.model, self.metrics) is assigned.
        super().__init__()
        self.model = model
        self.compile_params = compile_params
        # Returned as a torch.nn.ModuleDict (not a plain dict) so the metric
        # objects are registered as submodules and are moved to the same
        # device as the model together with it.
        self.metrics = self._prepare_metrics()

    def training_step(self, batch, batch_idx):
        """Compute and log the summed loss for one training batch.

        Args:
            batch: ``(x, y)`` pair; ``y`` is presumably a dict of label
                tensors keyed like the model's outputs (see ``_compute_loss``).
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar loss tensor Lightning backpropagates.
        """
        x, y = batch
        y_hat = self.model(x)
        loss = self._compute_loss(y_hat, y)
        # Logged both per step ("train_loss_step") and aggregated over the
        # epoch ("train_loss_epoch").
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute/log the validation loss and accumulate metric state.

        Args:
            batch: ``(x, y)`` pair, same layout as in ``training_step``.
            batch_idx: Index of the batch (unused).

        Returns:
            The scalar validation loss tensor.
        """
        x, y = batch
        y_pred = self.model(x)
        loss = self._compute_loss(y_pred, y)
        # on_epoch=True: Lightning aggregates the per-step values over the
        # whole epoch rather than reporting only the last batch.
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self._update_metrics(y_pred, y)
        return loss

    def on_validation_epoch_end(self):
        """Log the metrics accumulated during this epoch, then reset them."""
        self._compute_metrics()

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described by ``compile_params``.

        Returns:
            ``([optimizer], [scheduler])`` in the list form Lightning accepts.
        """
        optimizer = self.compile_params.optimizer(
            self.model.parameters(), lr=self.compile_params.start_lr
        )
        scheduler = self.compile_params.lr_scheduler(
            optimizer, **self.compile_params.scheduler_kwargs
        )
        return [optimizer], [scheduler]

    def _compute_loss(self, outputs: dict, labels: dict) -> torch.Tensor:
        """Return the unweighted sum of per-output MSE losses.

        Args:
            outputs: Model outputs keyed by output name.
            labels: Targets keyed by the same names; every key of ``outputs``
                must be present here.

        Returns:
            Scalar tensor: sum over outputs of ``mse_loss(output, label)``.

        Raises:
            ValueError: If ``outputs`` is empty (there would be nothing to
                backpropagate through).
            KeyError: If an output has no matching label.
        """
        loss_values_all = [
            torch.nn.functional.mse_loss(output_value, labels[output_name])
            for output_name, output_value in outputs.items()
        ]
        if not loss_values_all:
            # sum([]) would return the int 0, which is not a differentiable loss.
            raise ValueError("model returned no outputs to compute a loss on")
        return sum(loss_values_all)

    def _update_metrics(self, val_outputs: dict, val_labels: dict) -> None:
        """Feed each output/label pair into that output's MetricCollection."""
        for output_name, val_output_data in val_outputs.items():
            val_label_data = val_labels[output_name]
            current_metrics = self.metrics[output_name]
            current_metrics.update(val_output_data, val_label_data)

    def _compute_metrics(self):
        """Compute, log and reset every output's MetricCollection."""
        for metrics in self.metrics.values():
            self.log_dict(metrics.compute(), prog_bar=False, logger=True)
            # compute() leaves the accumulated state in place; without reset()
            # the next epoch's values would also include earlier epochs.
            metrics.reset()

    def _prepare_metrics(self) -> torch.nn.ModuleDict:
        """Build one prefixed MetricCollection per configured output.

        Returns:
            ModuleDict mapping output name -> ``MetricCollection`` whose logged
            keys are prefixed with ``"<output_name>_"``. NOTE(review):
            ModuleDict keys must be strings without ``"."`` — confirm the
            configured output names satisfy this.
        """
        metrics_dict = {}
        for output_name, metric_names in self.compile_params.metrics.items():
            curr_output_metrics = [
                get_metric(metric_name) for metric_name in metric_names
            ]
            metrics_dict[output_name] = MetricCollection(
                curr_output_metrics, prefix=output_name + "_"
            )
        return torch.nn.ModuleDict(metrics_dict)

Hey @emilaz
I don’t see anything immediately wrong with the code. You update the metrics during the steps and then compute + reset at the end of the epoch.

Do the metrics make sense? I.e., at the beginning of training, do they correspond to a randomly initialized model, and do they improve after training?