Find bottlenecks in your code (expert)

Audience: Users who want to build their own profilers.


Build your own profiler

To build your own profiler, subclass Profiler and override some of its methods. Here is a simple example that records when each action first occurred and how many times it was called, then reports every action that was called more than once:

from lightning.pytorch.profilers import Profiler
from collections import defaultdict
import time


class ActionCountProfiler(Profiler):
    """Profiler that records when each action first ran and how often it ran.

    ``summary()`` lists every action that was called more than once, together
    with the local time of its first occurrence and its total call count.
    """

    def __init__(self, dirpath=None, filename=None):
        super().__init__(dirpath=dirpath, filename=filename)
        # Number of completed (stopped) calls per action name.
        self._action_count = defaultdict(int)
        # Local time string of the first ``start`` call per action name.
        self._action_first_occurrence = {}

    def start(self, action_name: str) -> None:
        """Record the first time ``action_name`` starts; later starts are ignored."""
        if action_name not in self._action_first_occurrence:
            self._action_first_occurrence[action_name] = time.strftime("%m/%d/%Y, %H:%M:%S")

    def stop(self, action_name: str) -> None:
        """Count one completed call of ``action_name``."""
        self._action_count[action_name] += 1

    def summary(self) -> str:
        """Return a report of the actions that were called more than once.

        Each row holds the action name (left-aligned to the longest recorded
        name), the time of its first occurrence and its call count. Only the
        header is returned when no action was called more than once.
        """
        lines = ["\nProfile Summary: \n"]
        # ``default=0`` keeps a profiler with no recorded actions from raising ValueError.
        max_len = max((len(name) for name in self._action_count), default=0)

        for action_name, count in self._action_count.items():
            # generate summary for actions called more than once
            if count > 1:
                lines.append(
                    f"{action_name:<{max_len}s} \t "
                    f"{self._action_first_occurrence[action_name]} \t "
                    f"{count} \n"
                )

        return "".join(lines)

    def teardown(self, stage) -> None:
        """Reset the collected statistics, then run the base-class teardown."""
        # Recreate a defaultdict rather than a plain dict, so that ``stop`` keeps
        # working if the profiler is used again after teardown.
        self._action_count = defaultdict(int)
        self._action_first_occurrence = {}
        super().teardown(stage=stage)
# Pass an instance of the custom profiler to the Trainer.
# `...` is a placeholder for the arguments to fit() (e.g. your model).
# NOTE(review): this snippet does not import `Trainer`; it is assumed to come
# from `lightning.pytorch`.
trainer = Trainer(profiler=ActionCountProfiler())
trainer.fit(...)

Profile custom actions of interest

To profile a specific action of interest, store a reference to a profiler in your LightningModule:

from lightning.pytorch.profilers import SimpleProfiler, PassThroughProfiler


class MyModel(LightningModule):
    def __init__(self, profiler=None):
        # Initialize the LightningModule base class before setting attributes;
        # skipping this call leaves the module's base state uninitialized.
        super().__init__()
        # Fall back to PassThroughProfiler so that self.profiler.profile(...)
        # can be used even when no profiler is passed in.
        self.profiler = profiler or PassThroughProfiler()

To profile any part of your code, wrap it in the self.profiler.profile() context manager:

class MyModel(LightningModule):
    def custom_processing_step(self, data):
        """Run custom processing on ``data`` inside a profiled region.

        Everything executed in the ``with`` block is profiled by
        ``self.profiler`` under the action name "my_custom_action".
        ``data`` is returned as-is.
        """
        action_name = "my_custom_action"
        with self.profiler.profile(action_name):
            ...  # the processing you want to time goes here
        return data

Here’s the full code:

from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.profilers import SimpleProfiler, PassThroughProfiler


class MyModel(LightningModule):
    """LightningModule that profiles a custom processing step with an injected profiler."""

    def __init__(self, profiler=None):
        # Initialize the LightningModule base class before setting attributes;
        # skipping this call leaves the module's base state uninitialized.
        super().__init__()
        # Fall back to PassThroughProfiler so that self.profiler.profile(...)
        # can be used even when no profiler is passed in.
        self.profiler = profiler or PassThroughProfiler()

    def custom_processing_step(self, data):
        """Process ``data`` inside the "my_custom_action" profiled region and return it."""
        with self.profiler.profile("my_custom_action"):
            ...
        return data


# Share one SimpleProfiler between the model and the Trainer so that
# "my_custom_action" is recorded by the same profiler the Trainer is given.
profiler = SimpleProfiler()
model = MyModel(profiler)
trainer = Trainer(profiler=profiler, max_epochs=1)
# NOTE(review): nothing here calls model.custom_processing_step() or
# trainer.fit(); presumably you call the step from one of your hooks and then
# run trainer.fit(...) for "my_custom_action" to show up in the report.