I want to add a custom profiler to the PyTorch Lightning Trainer (the built-in PyTorchProfiler raises a runtime error saying it can't work with TorchScript modules). The problem is that, for some reason, the start and stop methods are called on every batch instead of just once. What is the correct way to write a custom profiler? Below is my workaround, which works for now.
import torch
from torch.profiler import profile, ProfilerActivity, tensorboard_trace_handler
from pytorch_lightning.profilers import Profiler

class CustomLightningProfiler(Profiler):
    def __init__(self, dirpath: str):
        super().__init__()
        self.dirpath = dirpath
        self.profiler = None
        self.current_step = 0
        # Each cycle is wait + warmup + active = 3 steps, repeated 5 times.
        self.schedule = torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=5)
        self.is_active = False
        # Start immediately so the profiler lives for the whole run.
        self.actual_start("training")

    def actual_start(self, action_name: str) -> None:
        # Create and start the underlying torch profiler exactly once;
        # repeated calls (e.g. from on_train_start) are no-ops.
        if self.profiler is None:
            self.profiler = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=self.schedule,
                on_trace_ready=tensorboard_trace_handler(self.dirpath),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
                with_modules=True,
            )
            self.profiler.start()
            self.is_active = True
            print("Profiler started.")

    def actual_stop(self, action_name: str) -> None:
        print(self.profiler, self.is_active)
        if (self.profiler is not None) and self.is_active:
            self.profiler.stop()
            print("Profiler stopped.")
            self.profiler = None
            self.is_active = False

    def step(self, action_name: str) -> None:
        # Advance the profiling schedule; meant to be called once per batch.
        if self.profiler is not None and self.is_active:
            self.profiler.step()
            self.current_step += 1
            print(f"Step {self.current_step} profiler step recorded.")

    def start(self, action_name: str) -> None:
        # Lightning calls this for every profiled action, hence the spam in the log below.
        print("Start called.")

    def stop(self, action_name: str) -> None:
        print("Stop called.")

    def summary(self) -> str:
        return f"Profiling data saved to {self.dirpath}"
With the LightningModule modified as:
def on_train_start(self):
    print("Training started.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
    self.trainer.profiler.actual_start("training")

# def on_train_batch_start(self, batch, batch_idx):
#     self.trainer.profiler.step("training_batch")

def on_train_batch_end(self, outputs, batch, batch_idx):
    print("Training batch end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
    self.trainer.profiler.step("training_batch_end")

def on_train_end(self):
    print("Training end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
    self.trainer.profiler.actual_stop("training_end")
    print(self.trainer.profiler.summary())
It is not perfect, but it is the closest I could get to a working setup. My training output looks like this:
Stop called.
Start called.
Epoch 0: 25%|███████████████████████████████████████████▌ | 1/4 [00:08<00:24, 8.00s/it, v_num=0]Stop called.
Start called.
Stop called.
Start called.
Stop called.
Start called.
Training batch end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
profiler step recorded.
Stop called.
Start called.
Stop called.
Start called.
Stop called.
Is my understanding of start and stop wrong?
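For comparison, here is how I understand a scheduled torch.profiler is meant to be driven outside Lightning (a minimal standalone sketch using the same schedule; train_one_batch is just a stand-in for a real training step):

import torch
from torch.profiler import profile, ProfilerActivity, schedule, tensorboard_trace_handler

def train_one_batch():
    # Stand-in for a real training step.
    x = torch.randn(64, 64)
    _ = x @ x

# wait=1, warmup=1, active=1 gives a 3-step cycle; repeat=5 writes 5 traces.
with profile(
    activities=[ProfilerActivity.CPU],
    schedule=schedule(wait=1, warmup=1, active=1, repeat=5),
    on_trace_ready=tensorboard_trace_handler("./tb_logs/profiler"),
) as prof:
    for _ in range(15):
        train_one_batch()
        prof.step()  # advance the schedule exactly once per batch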