Is there a way to log only on epoch end using the new Result APIs? With the old dict interface I was able to do this by forcing the step to be self.current_epoch (essentially ignoring global_step), but TrainResult doesn't appear to have a step field.
Unfortunately, that still counts the number of training steps elapsed (presumably because it records global_step as the step). See e.g. the TensorBoard output:
The data points should be logged at step 0 or 1 (the last epoch) instead of 22 (the number of training batches). In other words, setting on_step=False, on_epoch=True changes when the metric is aggregated, but not which step it is logged against.
Edit: here’s the relevant training_step implementation:
# Inside a LightningModule; assumes `import pytorch_lightning as pl`
# and `import torch.nn.functional as F` at module level.
def training_step(self, batch, _):
    x, y = batch
    logits = self.forward(x)
    loss = F.cross_entropy(logits, y)
    res = pl.TrainResult(loss)
    # Aggregate across the epoch and log once per epoch only.
    res.log("train_loss", loss, on_epoch=True, on_step=False)
    return res
Interesting, I never thought about returning an EvalResult from the training step methods! Unfortunately, that still doesn't log the correct step, but I was able to fix it by additionally logging result.log('step', self.current_epoch, ...).
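For anyone curious why the 'step' trick works: as far as I can tell, when Lightning hands metrics to the logger it pops a metric named "step" (if one was logged) and uses its value as the x-axis instead of global_step. Here's a minimal stand-in sketch of that behavior — FakeResult and step_for_logger are hypothetical illustrations, not actual Lightning APIs:

```python
class FakeResult:
    """Hypothetical stand-in for pl.TrainResult; just records logged metrics."""

    def __init__(self):
        self.metrics = {}

    def log(self, name, value, on_step=False, on_epoch=True):
        self.metrics[name] = value


def step_for_logger(metrics, global_step):
    # Mirrors the behavior described above: a logged "step" metric,
    # if present, overrides global_step as the x-axis value.
    return metrics.pop("step", global_step)


res = FakeResult()
res.log("train_loss", 0.5, on_step=False, on_epoch=True)
res.log("step", 1, on_step=False, on_epoch=True)  # i.e. self.current_epoch

# Without the "step" override the logger would use global_step (e.g. 22);
# with it, the metric is plotted at the epoch index instead.
assert step_for_logger(res.metrics, global_step=22) == 1
```

So logging 'step' once per epoch effectively swaps the x-axis from batch count to epoch count.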