Hello, I’m attempting to save additional information from model training in a custom callback. However, the information is not being saved to the checkpoint file. I’ve implemented the load_state_dict and state_dict methods as outlined in the documentation here.
I have the following simple custom callback implemented:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class MyCallback(Callback):
    def __init__(self) -> None:
        self.state = {"metric": None}

    def on_train_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
        # do stuff with trainer and pl_module to get the metric
        updated_metric = ...
        self.state["metric"] = updated_metric

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()
I initialize and train my trainer with the following callbacks:
trainer = pl.Trainer(..., callbacks=[EarlyStopping(...), ModelCheckpoint(...), MyCallback(...)])
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)
After training, direct inspection of the trainer’s MyCallback state indicates that the state has been properly updated with updated_metric. However, if I try to load the checkpoint created from training, the callback state does not persist and shows up as {"metric": None}. The docs linked above seem to indicate that simply implementing load_state_dict and state_dict is enough to “persist the state effectively”, but I’m not sure if I’m missing something here?
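For concreteness, this is roughly how I’m checking the in-memory state after fit (an illustrative snippet, not my exact code; I just grab the MyCallback instance from trainer.callbacks):

# Illustrative: check the in-memory callback state after trainer.fit(...)
my_cb = next(cb for cb in trainer.callbacks if isinstance(cb, MyCallback))
print(my_cb.state)  # shows the updated metric, as expected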
After training, I’m loading the checkpoint with:
loaded_state = torch.load("path/to/checkpoint.ckpt")
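To check the callback states, I then print each entry under the checkpoint’s “callbacks” key with something like this (illustrative; the exact key strings depend on each callback’s state_key):

# Illustrative: dump each callback's saved state from the checkpoint dict
for key, state in loaded_state["callbacks"].items():
    print(key, state)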
All my callbacks appear in the loaded state’s “callbacks” field, but only the ModelCheckpoint and EarlyStopping callbacks have persisted state. After stepping through MyCallback in my debugger, it seems like state_dict is getting called before on_train_end. Based on state_dict’s docstring, I know that this method gets called when a checkpoint is being saved, which means the state dict is getting saved before being updated by on_train_end. In light of this, I also tried implementing my metric calculation code in the on_save_checkpoint hook, but experienced the same result.
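For reference, the on_save_checkpoint attempt looked roughly like this (a sketch, assuming the Callback.on_save_checkpoint(trainer, pl_module, checkpoint) signature from v1.7; the metric computation itself is the same as before):

class MyCallback(Callback):
    ...
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # sketch of the variant I tried: compute the metric here instead of
        # in on_train_end, then store it on the callback's state
        updated_metric = ...
        self.state["metric"] = updated_metric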
Any help would be greatly appreciated. I’m running pytorch_lightning v1.7.7.