I’m finding my way around PyTorch Lightning and getting different results compared to vanilla PyTorch. I’m wondering whether this is because I’m logging the loss incorrectly. I have multiple outputs and I’m summing their losses — could this be the error? Here is my code. I’d also be interested in feedback on whether the metric calculations are correct!
class MultiOutputModule(pl.LightningModule):
    """LightningModule for a model that returns a dict of named outputs.

    The loss is the sum of the per-output MSE losses. During validation, one
    ``MetricCollection`` per output (built from ``compile_params.metrics``) is
    updated every step, then computed, logged and reset at the end of the
    validation epoch.

    Args:
        model: Network whose forward pass returns a dict mapping output name
            to prediction tensor (as consumed by ``_compute_loss``).
        compile_params: Project config object. This class reads its
            ``optimizer``, ``start_lr``, ``lr_scheduler``,
            ``scheduler_kwargs`` and ``metrics`` attributes.
    """

    def __init__(self, model: torch.nn.Module, compile_params: CompileParams):
        super().__init__()
        self.model = model
        self.compile_params = compile_params
        # A torch.nn.ModuleDict (not a plain dict): registering the metric
        # collections as child modules lets Lightning move their internal
        # states to the same device as the model. A plain dict leaves them on
        # the CPU while predictions live on the GPU.
        self.metrics = self._prepare_metrics()

    def training_step(self, batch, batch_idx):
        """Run one training step and log the summed multi-output loss."""
        x, y = batch
        y_hat = self.model(x)
        loss = self._compute_loss(y_hat, y)
        # With on_step=True and on_epoch=True, Lightning logs two values:
        # "train_loss_step" (the raw per-batch value) and "train_loss_epoch".
        # The epoch value is a mean over the epoch's steps, weighted by batch
        # size. When comparing against a vanilla PyTorch loop, compare it with
        # the same aggregate (not e.g. the last batch's loss, or an unweighted
        # mean).
        self.log(
            "train_loss", loss, on_step=True, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accumulate per-output metric state."""
        x, y = batch
        y_pred = self.model(x)
        loss = self._compute_loss(y_pred, y)
        # Epoch-level only: aggregated across validation batches (see the
        # note in training_step about how the epoch value is reduced).
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        self._update_metrics(y_pred, y)
        return loss

    def on_validation_epoch_end(self):
        """Compute, log and reset the metrics accumulated during validation."""
        self._compute_metrics()

    def configure_optimizers(self):
        """Build the optimizer and LR scheduler described by ``compile_params``."""
        optimizer = self.compile_params.optimizer(
            self.model.parameters(), lr=self.compile_params.start_lr
        )
        scheduler = self.compile_params.lr_scheduler(
            optimizer, **self.compile_params.scheduler_kwargs
        )
        # NOTE(review): returned in list form, so Lightning steps the scheduler
        # with its default settings (once per epoch). A scheduler that needs a
        # monitored value (e.g. ReduceLROnPlateau) must instead be returned as
        # a dict with a "monitor" key — confirm which schedulers are used. This
        # may also explain differences from a vanilla loop that steps per batch.
        return [optimizer], [scheduler]

    def _compute_loss(self, outputs: dict, labels: dict) -> torch.Tensor:
        """Return the sum of the MSE losses of all model outputs.

        Summing per-output losses is a valid multi-task objective. Each output
        contributes with weight 1, so outputs on larger numeric scales
        dominate the total.

        Args:
            outputs: Mapping of output name to prediction tensor.
            labels: Mapping of output name to target tensor. It must contain
                every key present in ``outputs``.
        """
        return sum(
            torch.nn.functional.mse_loss(prediction, labels[output_name])
            for output_name, prediction in outputs.items()
        )

    def _update_metrics(self, val_outputs: dict, val_labels: dict) -> None:
        """Feed each output's predictions and targets to its metric collection."""
        for output_name, val_output_data in val_outputs.items():
            self.metrics[output_name].update(val_output_data, val_labels[output_name])

    def _compute_metrics(self) -> None:
        """Compute, log and reset every per-output metric collection.

        Logged keys carry the ``"<output_name>_"`` prefix assigned in
        ``_prepare_metrics``.
        """
        for collection in self.metrics.values():
            self.log_dict(collection.compute(), prog_bar=False, logger=True)
            # Clear accumulated state so the next validation epoch starts fresh.
            collection.reset()

    def _prepare_metrics(self) -> torch.nn.ModuleDict:
        """Build one ``MetricCollection`` per output from ``compile_params.metrics``.

        Returns:
            A ``torch.nn.ModuleDict`` keyed by output name. Using a ModuleDict
            (rather than a plain dict) registers the metrics as child modules.

        NOTE(review): ModuleDict keys must be strings without "." and must not
        clash with ModuleDict attribute names (e.g. "items", "keys"). Confirm
        that the model's output names satisfy this.
        """
        return torch.nn.ModuleDict(
            {
                output_name: MetricCollection(
                    [get_metric(metric_name) for metric_name in metric_names],
                    prefix=output_name + "_",
                )
                for output_name, metric_names in self.compile_params.metrics.items()
            }
        )