State dict not persisting to checkpoint file for custom callback

Hello, I’m attempting to save additional information from model training via a custom callback, but the information is not being written to the checkpoint file. I’ve implemented the state_dict and load_state_dict methods as outlined in the documentation here.

I have the following simple custom callback implemented:

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback

class MyCallback(Callback):

    def __init__(self) -> None:
        self.state = {"metric": None}

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:

        # do stuff with trainer and pl_module to get the metric
        updated_metric = ...

        self.state["metric"] = updated_metric

    def load_state_dict(self, state_dict):
        # Called when a checkpoint is loaded; restores the callback's state.
        self.state.update(state_dict)

    def state_dict(self):
        # Called when a checkpoint is saved; returns the state to persist.
        return self.state.copy()

I initialize my trainer with the following callbacks and train:

from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = pl.Trainer(..., callbacks=[EarlyStopping(...), ModelCheckpoint(...), MyCallback()])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

After training, direct inspection of the trainer’s MyCallback state shows that the state has been properly updated with updated_metric. However, if I try to load the checkpoint created during training, the callback state does not persist and shows up as {"metric": None}. The docs linked above seem to indicate that simply implementing load_state_dict and state_dict is enough to “persist the state effectively”, so I’m not sure what I’m missing.

After training, I’m loading the checkpoint with:

import torch

loaded_state = torch.load("path/to/checkpoint.ckpt")
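Inspecting the per-callback entries shows the problem directly:

# Lightning stores each callback's state under the checkpoint's "callbacks" key.
print(loaded_state["callbacks"])
# MyCallback's entry shows {"metric": None}, while EarlyStopping and
# ModelCheckpoint have their expected state.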

All my callbacks appear in the loaded state’s “callbacks” field, but only the ModelCheckpoint and EarlyStopping callbacks have persisted state. After stepping through MyCallback in my debugger, it seems that state_dict gets called before on_train_end. Based on state_dict’s docstring, I know that this method gets called when a checkpoint is being saved, which means the state dict is saved before on_train_end updates it. In light of this, I also tried moving my metric calculation into the on_save_checkpoint hook, but got the same result.
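For reference, that attempt looked roughly like this (a sketch; the hook signature follows the v1.7 Callback API, with updated_metric standing in for my actual calculation):

    def on_save_checkpoint(self, trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: dict) -> None:
        # Compute the metric at save time rather than in on_train_end,
        # hoping it lands in the checkpoint being written.
        updated_metric = ...
        self.state["metric"] = updated_metric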

Any help would be greatly appreciated. I’m running pytorch_lightning v1.7.7.

UPDATE:

If I include my metric calculation in on_train_epoch_end rather than on_train_end, the state persists to the checkpoint. Is there a reason why state_dict is not called (i.e., the callback’s state is not saved) after the on_train_end hook? Is there a way to manually save the callback state after the on_train_end hook runs?
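For clarity, this is the variant that does persist (again with updated_metric standing in for my actual calculation):

    def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # Computing the metric here works: the checkpoint is written after this
        # hook runs, so state_dict picks up the updated value.
        updated_metric = ...
        self.state["metric"] = updated_metric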