Correct way to set up a custom profiler in PyTorch Lightning

I want to add a custom profiler to the PyTorch Lightning Trainer (the built-in PyTorchProfiler raises a runtime error saying it can't work with TorchScript modules). The problem is that the start and stop methods seem to be called on every batch, instead of just once for the whole run. What is the correct way to write a custom profiler? Below is my workaround, which works for now.
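
For reference, the only methods a custom Lightning profiler has to implement are start(action_name) and stop(action_name); summary is optional. A minimal logging sketch like the one below (assuming the Profiler base class from lightning.pytorch.profilers; on older versions it lives in pytorch_lightning.profilers) already shows the Trainer calling start/stop around every profiled action:

from lightning.pytorch.profilers import Profiler

class ActionLoggingProfiler(Profiler):
    # Logs every action the Trainer profiles; running this shows one
    # start/stop pair per action (per batch, per optimizer step, ...),
    # not a single pair for the whole training run.
    def start(self, action_name: str) -> None:
        print(f"start: {action_name}")

    def stop(self, action_name: str) -> None:
        print(f"stop: {action_name}")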

import torch
from torch.profiler import ProfilerActivity, profile, tensorboard_trace_handler

# On older versions the base class lives in pytorch_lightning.profilers instead.
from lightning.pytorch.profilers import Profiler


class CustomLightningProfiler(Profiler):
    def __init__(self, dirpath: str):
        super().__init__()
        self.dirpath = dirpath
        self.profiler = None
        self.current_step = 0
        # Repeat 5 cycles of: skip 1 step, warm up for 1 step, record 1 step.
        self.schedule = torch.profiler.schedule(wait=1, warmup=1, active=1, repeat=5)
        self.is_active = False
        self.actual_start("training")

    def actual_start(self, action_name: str) -> None:
        # Guard against double-starting (called from both __init__ and on_train_start).
        if self.is_active:
            return
        if self.profiler is None:
            self.profiler = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                schedule=self.schedule,
                on_trace_ready=tensorboard_trace_handler(self.dirpath),
                record_shapes=True,
                profile_memory=True,
                with_stack=True,
                with_flops=True,
                with_modules=True,
            )
        self.profiler.start()
        self.is_active = True
        print("Profiler started.")

    def actual_stop(self, action_name: str) -> None:
        if self.profiler is not None and self.is_active:
            self.profiler.stop()
            print("Profiler stopped.")
            self.profiler = None
            self.is_active = False

    def step(self, action_name: str) -> None:
        # Advance the torch profiler schedule once per training batch.
        if self.profiler is not None and self.is_active:
            self.profiler.step()
            self.current_step += 1
            print(f"Step {self.current_step} profiler step recorded.")

    # Lightning calls start/stop around every profiled action (hence the
    # per-batch "Start called." / "Stop called." output below), so these
    # are deliberately left as no-ops.
    def start(self, action_name: str) -> None:
        print("Start called.")

    def stop(self, action_name: str) -> None:
        print("Stop called.")

    def summary(self) -> str:
        return f"Profiling data saved to {self.dirpath}"
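
The profiler instance is passed to the Trainer as usual (the dirpath value here is just illustrative):

from lightning.pytorch import Trainer

trainer = Trainer(profiler=CustomLightningProfiler(dirpath="tb_logs/profiler"))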

The LightningModule is modified with these hooks:


    def on_train_start(self):
        print("Training started.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        self.trainer.profiler.actual_start("training")

    # def on_train_batch_start(self, batch, batch_idx):
    #     self.trainer.profiler.step("training_batch")

    def on_train_batch_end(self, outputs, batch, batch_idx):
        print("Training batch end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        self.trainer.profiler.step("training_batch_end")

    def on_train_end(self):
        print("Training end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<")
        self.trainer.profiler.actual_stop("training_end")
        print(self.trainer.profiler.summary())

It is not perfect, but it is the closest I could get to something that works. My training output looks like this:

Stop called.
Start called.
Epoch 0:  25%|███████████████████████████████████████████▌                                                                                                                                  | 1/4 [00:08<00:24,  8.00s/it, v_num=0]Stop called.
Start called.
Stop called.
Start called.
Stop called.
Start called.
Training batch end.<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<<
profiler step recorded.
Stop called.
Start called.
Stop called.
Start called.
Stop called.

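From skimming the Lightning source, my current understanding is that the base Profiler wraps every named Trainer action in a context manager roughly like this (a paraphrased sketch, not the exact library code):

from contextlib import contextmanager

@contextmanager
def profile(self, action_name):
    # start/stop bracket each individual action the Trainer profiles
    # (e.g. every run_training_batch), not the training run as a whole.
    try:
        self.start(action_name)
        yield action_name
    finally:
        self.stop(action_name)
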
Is my understanding of start and stop wrong?