Find bottlenecks in your code (expert)¶
Audience: Users who want to build their own profilers.
Build your own profiler¶
To build your own profiler, subclass Profiler
and override some of its methods. Here is a simple example that profiles the first occurrence and total calls of each action:
from lightning.pytorch.profilers import Profiler
from collections import defaultdict
import time
class ActionCountProfiler(Profiler):
def __init__(self, dirpath=None, filename=None):
super().__init__(dirpath=dirpath, filename=filename)
self._action_count = defaultdict(int)
self._action_first_occurrence = {}
def start(self, action_name):
if action_name not in self._action_first_occurrence:
self._action_first_occurrence[action_name] = time.strftime("%m/%d/%Y, %H:%M:%S")
def stop(self, action_name):
self._action_count[action_name] += 1
def summary(self):
res = f"\nProfile Summary: \n"
max_len = max(len(x) for x in self._action_count)
for action_name in self._action_count:
# generate summary for actions called more than once
if self._action_count[action_name] > 1:
res += (
f"{action_name:<{max_len}s} \t "
+ "self._action_first_occurrence[action_name]} \t "
+ "{self._action_count[action_name]} \n"
)
return res
def teardown(self, stage):
self._action_count = {}
self._action_first_occurrence = {}
super().teardown(stage=stage)
trainer = Trainer(profiler=ActionCountProfiler())
trainer.fit(...)
Profile custom actions of interest¶
To profile a specific action of interest, reference a profiler in the LightningModule.
from lightning.pytorch.profilers import SimpleProfiler, PassThroughProfiler
class MyModel(LightningModule):
def __init__(self, profiler=None):
self.profiler = profiler or PassThroughProfiler()
To profile in any part of your code, use the self.profiler.profile() function
class MyModel(LightningModule):
def custom_processing_step(self, data):
with self.profiler.profile("my_custom_action"):
...
return data
Here’s the full code:
from lightning.pytorch.profilers import SimpleProfiler, PassThroughProfiler
class MyModel(LightningModule):
def __init__(self, profiler=None):
self.profiler = profiler or PassThroughProfiler()
def custom_processing_step(self, data):
with self.profiler.profile("my_custom_action"):
...
return data
profiler = SimpleProfiler()
model = MyModel(profiler)
trainer = Trainer(profiler=profiler, max_epochs=1)