Hello, I’d like to collect and publish generic throughput metrics: e.g. collect the number of samples in `training_step` (using `extract_batch_size`) and the step duration, then `log()` a metric on `training_step_end` (current value, running average, last 20 steps, etc.), which is why I need access to the underlying `LightningModule`.
An obvious choice seems to be creating a new Callback, but I don’t want the duration metric to be skewed by the duration of other callbacks (or to depend on their order). In other words, I need a guarantee that the critical section covers only the `LightningModule`’s `training_step`.
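For concreteness, here is a minimal sketch of that naive Callback version (names are my own; the hook signatures and the `extract_batch_size` import path are the ones I believe recent 1.x releases use, so treat them as assumptions). The clock starts in `on_train_batch_start` and stops in `on_train_batch_end`, so whatever other callbacks do in between gets counted too:

```python
import time
from collections import deque

from pytorch_lightning import Callback
from pytorch_lightning.utilities.data import extract_batch_size


class ThroughputCallback(Callback):
    def __init__(self, window: int = 20):
        super().__init__()
        self.throughputs = deque(maxlen=window)  # samples/sec of the last N steps

    def on_train_batch_start(self, trainer, pl_module, batch, batch_idx):
        self.batch_size = extract_batch_size(batch)
        self.start = time.perf_counter()  # also times other callbacks' hooks!

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        duration = time.perf_counter() - self.start
        throughput = self.batch_size / duration
        self.throughputs.append(throughput)
        pl_module.log("samples_per_sec", throughput)  # current value
        pl_module.log(  # running average over the window
            "samples_per_sec_avg", sum(self.throughputs) / len(self.throughputs)
        )
```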
Since this functionality is a cross-cutting concern, it can’t be invasive: we don’t want to alter the original model (e.g. an open-source one); we want to wrap it.
I tried wrapping the model in a `LightningModule` but quickly got stuck on the intricacies of how `torch.nn.Module` handles attribute access, as well as how PyTorch Lightning detects the available `LightningModule` methods. Does PyTorch Lightning have a boilerplate solution for this?
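Roughly the shape of what I attempted, as a hypothetical sketch (`TimedModule` is my name, not a PL API), with comments marking where it breaks down:

```python
import time

import pytorch_lightning as pl


class TimedModule(pl.LightningModule):
    def __init__(self, inner: pl.LightningModule):
        super().__init__()
        # nn.Module.__setattr__ registers `inner` as a submodule, so its
        # parameters are tracked -- but none of its other attributes are.
        self.inner = inner

    def training_step(self, batch, batch_idx):
        start = time.perf_counter()
        out = self.inner.training_step(batch, batch_idx)  # the critical section
        self.log("training_step_duration", time.perf_counter() - start)
        return out

    def configure_optimizers(self):
        return self.inner.configure_optimizers()

    # ...and so on for every hook the inner module might define. A generic
    # __getattr__ passthrough fights nn.Module's attribute machinery, and
    # Lightning decides which hooks to run by checking whether they are
    # overridden, which a dynamic passthrough defeats.
```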
I tried looking into plugins, but they seem to be tightly coupled with the strategies.
What’s the organic PyTorch Lightning solution for a problem like this?
Thank you!