Hey guys, I have a question about overriding LightningModule. The newest version removed the *_epoch_end() functions but kept the return statements in all the *_step() functions. How can I access the values returned by the *_step() functions, like training_step(), validation_step(), and test_step()?
Why did the developers delete the *_epoch_end() functions but leave the return statements in *_step()?
Hi @onbigion13
For example, here in the release notes, in the section “The *_epoch_end hooks were removed”, you’ll find an explanation of why they were removed and how you can access your outputs now.
I’ll include the code snippet here as well:
```python
import lightning as L
import torch


class LitModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        # 1. Create a list to hold the outputs of `*_step`
        self.bananas = []

    def training_step(self, batch, batch_idx):
        ...
        # 2. Add the outputs to the list
        # You should be aware of the implications on memory usage
        self.bananas.append(banana)
        return loss

    # 3. Rename the hook to `on_*_epoch_end`
    def on_train_epoch_end(self):
        # 4. Do something with all outputs
        avg_banana = torch.cat(self.bananas).mean()
        # Don't forget to clear the memory for the next epoch!
        self.bananas.clear()
```
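To make the pattern concrete, here is a minimal sketch of what happens over one epoch, stripped of the Lightning machinery (the loop and dummy tensors are my own, just to stand in for the per-batch outputs): each "step" appends a tensor to a list, and the "epoch end" code concatenates, reduces, and clears.

```python
import torch

# Stand-in for `self.bananas` in the LightningModule above
outputs = []

# Simulate three training steps, each producing a per-batch tensor
for batch_idx in range(3):
    # dummy output: a 1-D tensor of 4 values, all equal to batch_idx
    outputs.append(torch.full((4,), float(batch_idx)))

# Epoch end: concatenate all step outputs and reduce to one scalar
avg = torch.cat(outputs).mean()

# Clear the list so memory doesn't grow across epochs
outputs.clear()

print(avg.item())    # 1.0 (mean of 0,0,0,0, 1,1,1,1, 2,2,2,2)
print(len(outputs))  # 0
```

The key design point is that the list lives on the module (`self.bananas`), so it survives between the per-batch hook and the epoch-end hook; clearing it at the end of `on_train_epoch_end` is what keeps memory bounded.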
but left the return operations in *_step(), why?
We can’t really “remove” the return statement itself, and for training_step it is actually mandatory to return the loss so Lightning can run backpropagation.
Hope this helps