How to interpret simple profiler results?

Using profiler="simple" I got the following result (first ~10 lines):

|  Action                                              |  Total time (s)
|  run_training_epoch                                  |  127.88
|  run_training_batch                                  |  71.237
|  [LightningModule]SegmentationModule.optimizer_step  |  71.221
|  [Strategy]SingleDeviceStrategy.training_step        |  54.01
|  [Strategy]SingleDeviceStrategy.validation_step      |  51.581
|  [Callback]TQDMProgressBar.on_train_batch_end        |  1.727
|  [Strategy]SingleDeviceStrategy.backward             |  1.037
|  [_TrainingEpochLoop].train_dataloader_next          |  1.022

Do I interpret it correctly that

  • out of the 127s training epoch, around 51s are spent on validation (validation_step) and 71s on training (optimizer_step)? This leaves only about 5s for everything else (in fast_dev_run mode).
  • out of the 71s in optimizer_step, 54s are spent on the forward pass (training_step)? What are the remaining 17 seconds spent on? The backward pass has a total time of only 1s, so is it likely something else in optimizer_step?
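To make the arithmetic behind these two questions explicit, here is a small sketch that subtracts the nested timings from their parents. It assumes that validation_step and optimizer_step both run inside run_training_epoch, and that training_step and backward run inside optimizer_step; the dictionary keys are shortened names chosen for this example, not the profiler's full action names.

```python
# Total times in seconds, copied from the profiler table above.
times = {
    "run_training_epoch": 127.88,
    "optimizer_step": 71.221,
    "training_step": 54.01,
    "validation_step": 51.581,
    "backward": 1.037,
}

# Epoch time not explained by optimizer_step and validation_step
# (assumes both are nested inside run_training_epoch).
epoch_rest = (
    times["run_training_epoch"]
    - times["optimizer_step"]
    - times["validation_step"]
)

# optimizer_step time not explained by training_step and backward
# (assumes both are nested inside optimizer_step).
step_rest = (
    times["optimizer_step"]
    - times["training_step"]
    - times["backward"]
)

print(f"epoch remainder:          {epoch_rest:.2f} s")
print(f"optimizer_step remainder: {step_rest:.2f} s")
```

Because only the first rows of the table are shown, both remainders are upper bounds: any action missing from the excerpt would still be counted in them.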

Reference: Optimization — PyTorch Lightning 2.4.0 documentation

Okay, after some number crunching and code checking, the following would make sense to me:

  • run_training_epoch = train_dataloader_next + optimizer_step + val_dataloader_next + validation_step
  • optimizer_step = training_step + zero_grad + backward + remaining-part

The remaining part of optimizer_step is likely the weight update, which happens outside the closure defined in Lightning and is therefore not traced as a separate action; see pytorch/torch/optim/adam.py at a8467c17c362bff41ceb0a6c1dec3c72454d242b · pytorch/pytorch · GitHub
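The effect can be reproduced with a toy model, sketched below under stated assumptions. ToyProfiler, FakeClock, and toy_optimizer_step are hypothetical stand-ins written for this example (they are not Lightning or PyTorch APIs), and the durations are arbitrary illustration values. The only thing borrowed from the real code is the shape of step(closure): run the closure first, then update the weights.

```python
from contextlib import contextmanager


class FakeClock:
    """Deterministic clock so the example does not depend on real timing."""

    def __init__(self):
        self.now = 0.0

    def advance(self, seconds):
        self.now += seconds


class ToyProfiler:
    """Hypothetical stand-in for a profiler that times named actions."""

    def __init__(self, clock):
        self.clock = clock
        self.totals = {}

    @contextmanager
    def profile(self, action):
        start = self.clock.now
        try:
            yield
        finally:
            elapsed = self.clock.now - start
            self.totals[action] = self.totals.get(action, 0.0) + elapsed


clock = FakeClock()
profiler = ToyProfiler(clock)


def closure():
    # Traced parts: each one is wrapped in its own profiled action.
    with profiler.profile("training_step"):
        clock.advance(5.0)
    with profiler.profile("zero_grad"):
        clock.advance(0.5)
    with profiler.profile("backward"):
        clock.advance(1.0)


def toy_optimizer_step(closure):
    # Same shape as an optimizer's step(closure): closure first, then update.
    closure()
    clock.advance(2.0)  # weight update: has no profiled action of its own


with profiler.profile("optimizer_step"):
    toy_optimizer_step(closure)

traced = sum(profiler.totals[k] for k in ("training_step", "zero_grad", "backward"))
untraced = profiler.totals["optimizer_step"] - traced

print(f"optimizer_step total: {profiler.totals['optimizer_step']:.1f} s")
print(f"traced inside it:     {traced:.1f} s")
print(f"untraced remainder:   {untraced:.1f} s")
```

In this toy setup the weight update shows up only as the gap between the optimizer_step total and the sum of its traced children, which matches the pattern in the table above.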