I want to know whether this code:
class MyModel(nn.Module):
    """Model whose ``train_step`` computes both the output and the loss."""

    # ... (rest of the model elided in the original snippet)

    def train_step(self, x):
        """Compute the output for ``x`` and its loss.

        Returns:
            A dict with the model output under ``"out"`` and the loss under
            ``"loss"``.

        Note: this is an ordinary method, so ``model.train_step(x)`` does not
        go through ``nn.Module.__call__``; forward pre-hooks / forward hooks
        registered on this module are not run for this computation.
        """
        out = compute_output(x)
        # The target given to the loss is the input ``x`` itself.
        loss = compute_loss(out, x)
        return {"out": out, "loss": loss}
class LitModule(pl.LightningModule):
    """Lightning wrapper that delegates the training computation to ``model.train_step``."""

    def __init__(self, model):
        super().__init__()
        # Must be assigned inside __init__: at class-body level ``self`` does
        # not exist. When ``model`` is an nn.Module, this assignment also
        # registers it as a child module of this LightningModule.
        self.model = model

    # ... (remaining hooks, e.g. configure_optimizers, elided in the original)

    def training_step(self, x):
        """Run one training step on the batch ``x`` and return its loss.

        The loss is computed inside ``self.model.train_step``. That method is
        called directly, bypassing ``nn.Module.__call__``, so no module hooks
        fire for it.
        """
        out = self.model.train_step(x)
        self.log_dict({"loss": out["loss"]})
        # With automatic optimization, Lightning backpropagates the returned loss.
        return out["loss"]
is exactly equivalent to this:
class MyModel(nn.Module):
    """Model whose ``forward`` returns only the output; the caller computes the loss."""

    # ... (rest of the model elided in the original snippet)

    def forward(self, x):
        """Return ``compute_output(x)``.

        Invoke it as ``model(x)``. That routes through ``nn.Module.__call__``,
        which runs any registered forward pre-hooks / forward hooks around
        this method.
        """
        return compute_output(x)
class LitModule(pl.LightningModule):
    """Lightning wrapper that runs ``model`` and applies a separately supplied ``loss``."""

    def __init__(self, model, loss):
        super().__init__()
        # Both assignments must live inside __init__: at class-body level
        # ``self`` does not exist.
        self.model = model
        # If ``loss`` is an nn.Module (e.g. a criterion object), it is
        # registered as a submodule as well, so it moves across devices with
        # this module and any parameters/buffers it has appear in state_dict.
        self.loss = loss

    # ... (remaining hooks, e.g. configure_optimizers, elided in the original)

    def forward(self, x):
        """Delegate to the wrapped model.

        ``self(x)`` dispatches to *this* method. Without it, the call falls
        through to nn.Module's default ``forward`` and raises
        NotImplementedError.
        """
        return self.model(x)

    def training_step(self, x):
        """Run one training step on the batch ``x`` and return its loss."""
        out = self(x)
        loss = self.loss(out, x)
        self.log_dict({"loss": loss})
        # With automatic optimization, Lightning backpropagates the returned loss.
        return loss