Compute loss in model's forward instead of lightning module's training_step

I wanted to know if this:

import torch.nn as nn
import pytorch_lightning as pl

class MyModel(nn.Module):
    ...
    def train_step(self, x):
        out = compute_output(x)
        loss = compute_loss(out, x)
        return {"out": out, "loss": loss}

class LitModule(pl.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
    ...
    def training_step(self, x):
        out = self.model.train_step(x)
        self.log_dict({"loss": out["loss"]})
        return out["loss"]

Is the same exact thing as this:

class MyModel(nn.Module):
    ...
    def forward(self, x):
        out = compute_output(x)
        return out

class LitModule(pl.LightningModule):
    def __init__(self, model, loss):
        super().__init__()
        self.model = model
        self.loss = loss
    ...
    def training_step(self, x):
        out = self(x)
        loss = self.loss(out, x)

        self.log_dict({"loss": loss})
        return loss

It should probably be out = self.model(x) instead of out = self(x), since LitModule doesn't define a forward of its own here, but yes, otherwise it looks pretty equivalent.


I usually define the forward method inside the lightning module as:

def forward(self, x):
    return self.model(x)

That should work, right? I was just worried that computing the loss inside the model's forward would cause trouble during backpropagation.

You can compute the loss in any module you want; you shouldn't be worried about backpropagation here. Autograd records the operations performed on tensors regardless of which module they run in, so the loss backpropagates the same way in either setup.
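
As a quick sanity check, here is a minimal sketch (the ToyModel class, the linear layer standing in for compute_output, and the MSE loss standing in for compute_loss are all assumptions for illustration) showing that gradients flow normally when the loss is computed inside forward:

import torch
import torch.nn as nn

class ToyModel(nn.Module):
    """Hypothetical model that computes its own loss inside forward."""
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)  # stands in for compute_output

    def forward(self, x):
        out = self.layer(x)
        # Loss computed inside forward; autograd traces it like any other op.
        loss = nn.functional.mse_loss(out, x)  # stands in for compute_loss
        return {"out": out, "loss": loss}

model = ToyModel()
result = model(torch.randn(2, 4))
result["loss"].backward()

# Gradients reached the parameters as usual.
print(model.layer.weight.grad is not None)  # prints True

The same holds once the module is wrapped in a LightningModule: training_step just has to return the loss tensor.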