.. _profiler:

#########
Profiling
#########

Profiling your training, testing, or inference run can help you identify bottlenecks in your code. Reports can be generated with ``trainer.fit()``, ``trainer.test()``, ``trainer.validate()``, and ``trainer.predict()`` for their respective actions.

------------

****************
Built-in Actions
****************

PyTorch Lightning supports profiling standard actions in the training loop out of the box, including:

- on_train_epoch_start
- on_train_epoch_end
- on_train_batch_start
- model_backward
- on_after_backward
- optimizer_step
- on_train_batch_end
- training_step_end
- on_train_end
- etc.

------------

*******************
Supported Profilers
*******************

Lightning provides the following profilers:

Simple Profiler
===============

If you only wish to profile the standard actions, you can set ``profiler="simple"``. It uses the built-in :class:`~pytorch_lightning.profiler.simple.SimpleProfiler`.

.. code-block:: python

    # by passing a string
    trainer = Trainer(..., profiler="simple")

    # or by passing an instance
    from pytorch_lightning.profiler import SimpleProfiler

    profiler = SimpleProfiler()
    trainer = Trainer(..., profiler=profiler)

The profiler's results will be printed at the completion of ``trainer.fit()``. Below is an example of a :class:`~pytorch_lightning.profiler.simple.SimpleProfiler` report showing a few of the recorded actions:

.. code-block::

    FIT Profiler Report

    --------------------------------------------------------------------------------------
    | Action                                        | Mean duration (s) | Total time (s) |
    --------------------------------------------------------------------------------------
    | run_training_epoch                            | 6.1558            | 6.1558         |
    | run_training_batch                            | 0.0022506         | 0.015754       |
    | [LightningModule]BoringModel.optimizer_step   | 0.0017477         | 0.012234       |
    | [LightningModule]BoringModel.val_dataloader   | 0.00024388        | 0.00024388     |
    | on_train_batch_start                          | 0.00014637        | 0.0010246      |
    | [LightningModule]BoringModel.teardown         | 2.15e-06          | 2.15e-06       |
    | [LightningModule]BoringModel.prepare_data     | 1.955e-06         | 1.955e-06      |
    | [LightningModule]BoringModel.on_train_start   | 1.644e-06         | 1.644e-06      |
    | [LightningModule]BoringModel.on_train_end     | 1.516e-06         | 1.516e-06      |
    | [LightningModule]BoringModel.on_fit_end       | 1.426e-06         | 1.426e-06      |
    | [LightningModule]BoringModel.setup            | 1.403e-06         | 1.403e-06      |
    | [LightningModule]BoringModel.on_fit_start     | 1.226e-06         | 1.226e-06      |
    --------------------------------------------------------------------------------------

.. note::
    The final report contains many more actions, along with the call count and the percentage of total time for each one.

Advanced Profiler
=================

If you want more information on the functions called during each event, you can use the :class:`~pytorch_lightning.profiler.advanced.AdvancedProfiler`. This option uses Python's `cProfile <https://docs.python.org/3/library/profile.html#module-cProfile>`_ to provide an in-depth report of the time spent within *each* function called in your code.

.. code-block:: python

    # by passing a string
    trainer = Trainer(..., profiler="advanced")

    # or by passing an instance
    from pytorch_lightning.profiler import AdvancedProfiler

    profiler = AdvancedProfiler()
    trainer = Trainer(..., profiler=profiler)

The profiler's results will be printed at the completion of ``trainer.fit()``.
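
Because ``AdvancedProfiler`` is built on ``cProfile``, you can produce a similar per-function report outside of Lightning with the standard library alone. The following is a minimal sketch, not Lightning's implementation: ``load_batch`` is a hypothetical stand-in for data-loading work, and the statistics are sorted by cumulative time and limited to the 10 most expensive entries, like the ``get_train_batch`` report in this section.

.. code-block:: python

    import cProfile
    import io
    import pstats


    def load_batch(n):
        # Hypothetical stand-in for the work done while fetching a batch.
        return [sum(range(i % 100)) for i in range(n)]


    profiler = cProfile.Profile()
    profiler.enable()
    for _ in range(5):
        load_batch(10_000)
    profiler.disable()

    # Write the statistics into a string buffer: strip directory names,
    # sort by cumulative time and keep only the top 10 rows.
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream)
    stats.strip_dirs().sort_stats("cumulative").print_stats(10)
    report = stream.getvalue()
    print(report)

Sorting by cumulative time puts the outermost expensive call chains at the top of the listing, since each function's cumulative time includes the time spent in everything it calls.
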
This profiler report can be quite long, so you can also specify a ``dirpath`` and ``filename`` to write the report to a file instead of printing it in your terminal. The output below shows the profiling for the action ``get_train_batch``:

.. code-block::

    Profiler Report

    Profile stats for: get_train_batch
             4869394 function calls (4863767 primitive calls) in 18.893 seconds

       Ordered by: cumulative time
       List reduced from 76 to 10 due to restriction <10>

       ncalls  tottime  percall  cumtime  percall filename:lineno(function)
    3752/1876    0.011    0.000   18.887    0.010 {built-in method builtins.next}
         1876    0.008    0.000   18.877    0.010 dataloader.py:344(__next__)
         1876    0.074    0.000   18.869    0.010 dataloader.py:383(_next_data)
         1875    0.012    0.000   18.721    0.010 fetch.py:42(fetch)
         1875    0.084    0.000   18.290    0.010 fetch.py:44(<listcomp>)
        60000    1.759    0.000   18.206    0.000 mnist.py:80(__getitem__)
        60000    0.267    0.000   13.022    0.000 transforms.py:68(__call__)
        60000    0.182    0.000    7.020    0.000 transforms.py:93(__call__)
        60000    1.651    0.000    6.839    0.000 functional.py:42(to_tensor)
        60000    0.260    0.000    5.734    0.000 transforms.py:167(__call__)

PyTorch Profiler
================

PyTorch includes a `profiler <https://pytorch.org/docs/master/profiler.html>`__ that lets you inspect the cost of different operators inside your model, on both the CPU and the GPU. It is used by our :class:`~pytorch_lightning.profiler.pytorch.PyTorchProfiler`.

.. code-block:: python

    # by passing a string
    trainer = Trainer(..., profiler="pytorch")

    # or by passing an instance
    from pytorch_lightning.profiler import PyTorchProfiler

    profiler = PyTorchProfiler()
    trainer = Trainer(..., profiler=profiler)

This profiler works with multi-device settings. If ``filename`` is provided, each rank will save its profiled operations to its own file. The profiler report can be quite long, so setting a ``filename`` will save the report instead of printing it in your terminal. If no filename is given, the report will be printed only on rank 0.

The profiler's results will be reported at the completion of ``{fit,validate,test,predict}``.

This profiler will record ``training_step``, ``backward``, ``validation_step``, ``test_step``, and ``predict_step`` by default. You can provide ``PyTorchProfiler(record_functions={...})`` to extend the scope of profiled functions.

.. note::
    When using the PyTorch Profiler, wall clock time will not be representative of the true wall clock time. This is because profiled operations are forced to be measured synchronously, whereas many CUDA operations run asynchronously. Use this profiler to find bottlenecks and per-operator breakdowns; for end-to-end wall clock time, use the ``SimpleProfiler``.

The output below shows the profiling for the action ``training_step``:

.. code-block::

    Profiler Report

    Profile stats for: training_step
    ---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
    Name                   Self CPU total % Self CPU total   CPU total %      CPU total        CPU time avg
    ---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
    t                      62.10%           1.044ms          62.77%           1.055ms          1.055ms
    addmm                  32.32%           543.135us        32.69%           549.362us        549.362us
    mse_loss               1.35%            22.657us         3.58%            60.105us         60.105us
    mean                   0.22%            3.694us          2.05%            34.523us         34.523us
    div_                   0.64%            10.756us         1.90%            32.001us         16.000us
    ones_like              0.21%            3.461us          0.81%            13.669us         13.669us
    sum_out                0.45%            7.638us          0.74%            12.432us         12.432us
    transpose              0.23%            3.786us          0.68%            11.393us         11.393us
    as_strided             0.60%            10.060us         0.60%            10.060us         3.353us
    to                     0.18%            3.059us          0.44%            7.464us          7.464us
    empty_like             0.14%            2.387us          0.41%            6.859us          6.859us
    empty_strided          0.38%            6.351us          0.38%            6.351us          3.175us
    fill_                  0.28%            4.782us          0.33%            5.566us          2.783us
    expand                 0.20%            3.336us          0.28%            4.743us          4.743us
    empty                  0.27%            4.456us          0.27%            4.456us          2.228us
    copy_                  0.15%            2.526us          0.15%            2.526us          2.526us
    broadcast_tensors      0.15%            2.492us          0.15%            2.492us          2.492us
    size                   0.06%            0.967us          0.06%            0.967us          0.484us
    is_complex             0.06%            0.961us          0.06%            0.961us          0.481us
    stride                 0.03%            0.517us          0.03%            0.517us          0.517us
    ---------------------  ---------------  ---------------  ---------------  ---------------  ---------------
    Self CPU time total: 1.681ms

When running with ``PyTorchProfiler(emit_nvtx=True)``, you should run your script as follows:

.. code-block:: bash

    nvprof --profile-from-start off -o trace_name.prof -- <regular command here>

To visualize the profiled operations, you can either open the trace in ``nvvp``:

.. code-block:: bash

    nvvp trace_name.prof

or load it in Python:

.. code-block:: bash

    python -c 'import torch; print(torch.autograd.profiler.load_nvprof("trace_name.prof"))'

XLA Profiler
============

:class:`~pytorch_lightning.profiler.xla.XLAProfiler` helps you debug and optimize the training performance of your models using Cloud TPU performance tools.

.. code-block:: python

    # by passing the `XLAProfiler` alias
    trainer = Trainer(..., profiler="xla")

    # or by passing an instance
    from pytorch_lightning.profiler import XLAProfiler

    profiler = XLAProfiler(port=9001)
    trainer = Trainer(..., profiler=profiler)

Manual Capture via TensorBoard
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

The following instructions are for capturing traces from a running program:

1. Follow this `guide <https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm#tpu-vm>`_ to set up Cloud TPU with the required installations.

2. Start a `TensorBoard <https://www.tensorflow.org/tensorboard>`_ server. You can view the TensorBoard output at ``http://localhost:9001`` on your local machine, and then open the ``PROFILE`` plugin from the top-right dropdown or open ``http://localhost:9001/#profile``.

   .. code-block:: bash

       tensorboard --logdir ./tensorboard --port 9001

3. Once the code you'd like to profile is running, click the ``CAPTURE PROFILE`` button. Enter ``localhost:9001`` (the port passed to ``XLAProfiler`` above) as the Profile Service URL. Then enter the number of milliseconds for the profiling duration, and click ``CAPTURE``.

4. Make sure the code is still running while you capture the traces. You will get better performance insights if the profiling duration is longer than the step time.

5. Once the capture is finished, the page will refresh and you can browse through the insights using the ``Tools`` dropdown at the top left.

----------------

****************
Custom Profiling
****************

Custom Profiler
===============

You can also configure a custom profiler and pass it to the Trainer. To configure it, subclass :class:`~pytorch_lightning.profiler.base.Profiler` and override some of its methods. The following is a simple example that records when each action first occurred and how many times it was called, and reports the actions that ran more than once:

.. code-block:: python

    import time
    from collections import defaultdict

    from pytorch_lightning.profiler import Profiler


    class ActionCountProfiler(Profiler):
        def __init__(self, dirpath=None, filename=None):
            super().__init__(dirpath=dirpath, filename=filename)
            self._action_count = defaultdict(int)
            self._action_first_occurrence = {}

        def start(self, action_name):
            # remember the wall-clock time of the first call only
            if action_name not in self._action_first_occurrence:
                self._action_first_occurrence[action_name] = time.strftime("%m/%d/%Y, %H:%M:%S")

        def stop(self, action_name):
            self._action_count[action_name] += 1

        def summary(self):
            res = "\nProfile Summary: \n"
            max_len = max((len(x) for x in self._action_count), default=0)
            for action_name, count in self._action_count.items():
                # generate summary for actions called more than once
                if count > 1:
                    res += (
                        f"{action_name:<{max_len}s} \t "
                        f"{self._action_first_occurrence[action_name]} \t "
                        f"{count} \n"
                    )
            return res

        def teardown(self, stage):
            # reset state, keeping a defaultdict so that stop() keeps working
            self._action_count = defaultdict(int)
            self._action_first_occurrence = {}
            super().teardown(stage=stage)

.. code-block:: python

    trainer = Trainer(..., profiler=ActionCountProfiler())
    trainer.fit(...)

Profile Logic of Your Interest
==============================

You can also reference this profiler in your LightningModule to profile specific actions of interest.
Each profiler has a ``profile()`` method that returns a context manager. Pass in the name of the action you want to track, and the profiler will record the performance of the code executed within this context.

.. code-block:: python

    from pytorch_lightning import LightningModule, Trainer
    from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler


    class MyModel(LightningModule):
        def __init__(self, profiler=None):
            super().__init__()
            # fall back to a no-op profiler if none is provided
            self.profiler = profiler or PassThroughProfiler()

        def custom_processing_step(self, data):
            # everything inside this block is recorded under "my_custom_action"
            with self.profiler.profile("my_custom_action"):
                ...
            return data


    profiler = SimpleProfiler()
    model = MyModel(profiler)
    trainer = Trainer(profiler=profiler, max_epochs=1)
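
Conceptually, ``profile()`` brackets the wrapped code with the profiler's ``start()`` and ``stop()`` calls for the given action name. The sketch below illustrates that pattern with the standard library only. ``TinyProfiler`` is a hypothetical class that reuses the ``start``/``stop``/``profile``/``summary`` method names from this page; it is not a Lightning API, and its summary (call count, mean and total duration per action) only loosely resembles the ``SimpleProfiler`` report.

.. code-block:: python

    import time
    from collections import defaultdict
    from contextlib import contextmanager


    class TinyProfiler:
        """Hypothetical, Lightning-free profiler illustrating the start/stop/profile pattern."""

        def __init__(self):
            self._start_times = {}
            self.durations = defaultdict(list)

        def start(self, action_name):
            self._start_times[action_name] = time.perf_counter()

        def stop(self, action_name):
            started = self._start_times.pop(action_name)
            self.durations[action_name].append(time.perf_counter() - started)

        @contextmanager
        def profile(self, action_name):
            self.start(action_name)
            try:
                yield action_name
            finally:
                # runs even if the profiled block raises
                self.stop(action_name)

        def summary(self):
            lines = []
            for name, times in self.durations.items():
                total = sum(times)
                lines.append(
                    f"{name}: calls={len(times)}, mean={total / len(times):.6f}s, total={total:.6f}s"
                )
            return "\n".join(lines)


    profiler = TinyProfiler()
    for _ in range(3):
        with profiler.profile("my_custom_action"):
            sum(i * i for i in range(10_000))
    print(profiler.summary())

Stopping in a ``finally`` clause ensures each ``start()`` is matched by a ``stop()`` even when the profiled code raises an exception.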