Hi! In my `LightningModule`, I have an `on_train_batch_end` hook that post-processes the results of `training_step` (for example, normalising the logits and updating metrics). I also have callbacks with their own `on_train_batch_end` method, which expect the model outputs to be normalised already.
However, I found that the execution order is `LightningModule.training_step` → `Callback.on_train_batch_end` → `LightningModule.on_train_batch_end`, whereas I was expecting `LightningModule.training_step` → `LightningModule.on_train_batch_end` → `Callback.on_train_batch_end`, which seems more natural to me.
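To make the observed order concrete, here is a minimal pure-Python stand-in for what happens at the end of a training batch. It is not Lightning's actual loop code. `Module`, `Callback` and `run_batch` are hypothetical names, and `run_batch` simply hard-codes the order I observed (all callback hooks first, then the module's hook).

```python
from typing import Any, Dict, List


class Module:
    """Hypothetical stand-in for a LightningModule that records hook calls."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        self.log.append("LightningModule.training_step")
        return {"logits": batch}

    def on_train_batch_end(self, outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        self.log.append("LightningModule.on_train_batch_end")


class Callback:
    """Hypothetical stand-in for a Lightning Callback that records hook calls."""

    def __init__(self, log: List[str]) -> None:
        self.log = log

    def on_train_batch_end(self, trainer: Any, pl_module: Module,
                           outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        self.log.append("Callback.on_train_batch_end")


def run_batch(module: Module, callbacks: List[Callback], batch: Any, batch_idx: int) -> Dict[str, Any]:
    outputs = module.training_step(batch, batch_idx)
    # Observed order: every callback's hook runs first...
    for cb in callbacks:
        cb.on_train_batch_end(None, module, outputs, batch, batch_idx)
    # ...and only afterwards the module's own hook.
    module.on_train_batch_end(outputs, batch, batch_idx)
    return outputs


log: List[str] = []
run_batch(Module(log), [Callback(log)], batch=[0.5, -1.0], batch_idx=0)
print(" -> ".join(log))
```

The printed sequence matches the order described above, which is the opposite of what I expected.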
I don't want to put this normalisation in `training_step`, because I want `training_step` to be an abstract method: child classes override `training_step`, and the shared `on_train_batch_end` is then supported out of the box.
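The intended design can be sketched in plain Python with the standard `abc` module. This is a stand-in, not my actual code: `BaseModule` and `UNetLike` are hypothetical names, the sigmoid is an assumed choice of normalisation, and in a real project the base class would presumably also inherit from `LightningModule`. The hook signature mirrors `on_train_batch_end(outputs, batch, batch_idx)`.

```python
import math
from abc import ABC, abstractmethod
from typing import Any, Dict, List


class BaseModule(ABC):
    """Hypothetical base class: the shared post-processing lives here."""

    def __init__(self) -> None:
        self.metric_inputs: List[List[float]] = []

    @abstractmethod
    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        """Children must return a dict containing the raw "logits"."""

    def on_train_batch_end(self, outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        # Provided out of the box: normalise the logits (sigmoid assumed)
        # and hand them to the (stubbed) metrics.
        outputs["probs"] = [1.0 / (1.0 + math.exp(-z)) for z in outputs["logits"]]
        self.metric_inputs.append(outputs["probs"])


class UNetLike(BaseModule):
    """Hypothetical child: only the model-specific step is implemented."""

    def training_step(self, batch: Any, batch_idx: int) -> Dict[str, Any]:
        return {"logits": [float(x) for x in batch]}


model = UNetLike()
out = model.training_step([0.0, 1.0], batch_idx=0)
model.on_train_batch_end(out, batch=None, batch_idx=0)
print(out["probs"])
```

Because `training_step` is abstract, the base class itself cannot be instantiated, and every child inherits the same post-processing without repeating it.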
Concrete example: a binary segmentation model. `LightningModule.training_step` produces the model `logits`; in `LightningModule.on_train_batch_end` I define how these are denormalised and pass them to the metrics; and I have a callback, `SegmentationWriter`, which expects the predictions to be denormalised already so it can visualise them and save them to disk. Note that I want to define and execute the denormalisation only once, so that the metrics and the segmentation writer use the same result (and I don't want to repeat myself in the segmentation callback).
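The consequence for this example can be shown with another pure-Python stand-in. `denormalise`, `SegModule`, `run` and this `SegmentationWriter` are hypothetical simplifications, the denormalisation is assumed to be a sigmoid followed by a 0.5 threshold, and the `callbacks_first` flag only switches between the observed order and the one I expected. It does not change anything in Lightning.

```python
import math
from typing import Any, Dict, List, Optional, Tuple


def denormalise(logits: List[float], threshold: float = 0.5) -> List[int]:
    """Hypothetical denormalisation: sigmoid, then threshold to a binary mask."""
    return [int(1.0 / (1.0 + math.exp(-z)) >= threshold) for z in logits]


class SegModule:
    def __init__(self) -> None:
        self.metric_inputs: List[List[int]] = []

    def training_step(self, batch: List[float], batch_idx: int) -> Dict[str, Any]:
        return {"logits": batch}

    def on_train_batch_end(self, outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        # Denormalise exactly once and store the result for every consumer.
        outputs["preds"] = denormalise(outputs["logits"])
        self.metric_inputs.append(outputs["preds"])


class SegmentationWriter:
    def __init__(self) -> None:
        self.saved: List[Optional[List[int]]] = []

    def on_train_batch_end(self, trainer: Any, pl_module: SegModule,
                           outputs: Dict[str, Any], batch: Any, batch_idx: int) -> None:
        # Relies on the module hook having run first; otherwise "preds" is missing.
        self.saved.append(outputs.get("preds"))


def run(callbacks_first: bool) -> Tuple[Optional[List[int]], List[int]]:
    module, writer = SegModule(), SegmentationWriter()
    outputs = module.training_step([2.0, -3.0, 0.1], batch_idx=0)
    hooks = [
        lambda: writer.on_train_batch_end(None, module, outputs, None, 0),
        lambda: module.on_train_batch_end(outputs, None, 0),
    ]
    for hook in (hooks if callbacks_first else reversed(hooks)):
        hook()
    return writer.saved[0], module.metric_inputs[0]


print(run(callbacks_first=True))   # observed order
print(run(callbacks_first=False))  # expected order
```

With the observed order the writer receives no predictions, while with the expected order the metrics and the writer share the single denormalised mask.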
Question: is this order intentional, or is there a specific reason why the callback's `on_train_batch_end` runs before the `LightningModule`'s? To me the other way around feels more natural. Can I change this order?
Thank you!