Save Callback state
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback's state as part of model checkpoint files using
state_dict() and load_state_dict().
Note that the returned state must be picklable.
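The pickling requirement can be checked before training starts. The following is a minimal sketch using only the standard library; is_picklable is a hypothetical helper written for this illustration and is not part of Lightning.

```python
import pickle


def is_picklable(state):
    """Return True if `state` survives a pickle round trip unchanged.

    Hypothetical helper for illustration; not part of Lightning.
    """
    try:
        return pickle.loads(pickle.dumps(state)) == state
    except Exception:  # pickle raises different error types depending on the object
        return False


# plain counters pickle without problems
good_state = {"epochs": 2, "batches": 32}
# a lambda stored in the state cannot be pickled
bad_state = {"epochs": 2, "on_done": lambda: None}

print(is_picklable(good_state))
print(is_picklable(bad_state))
```

Keeping the state to plain Python containers, numbers, and strings is usually the simplest way to satisfy this requirement.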
When your callback is only ever used as a single instance, implementing the two hooks above is enough
to persist its state. However, if passing multiple instances of the callback to the Trainer is supported,
the callback must also define a state_key property so that Lightning can distinguish the different states
when loading the callback state. This concept is best illustrated by the following example.
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import Callback


class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self):
        # note: we do not include `verbose` here on purpose
        return self._generate_state_key(what=self.what)

    def on_train_epoch_end(self, *args, **kwargs):
        if self.what == "epochs":
            self.state["epochs"] += 1

    def on_train_batch_end(self, *args, **kwargs):
        if self.what == "batches":
            self.state["batches"] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()


# two callbacks of the same type are being used
trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])
A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:
{
    "state_dict": ...,
    "callbacks": {
        "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
        "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
        ...
    }
}
The implementation of state_key is essential here. By default, state_key uses only the class name
as the key (here, Counter), so if it were missing, Lightning would not be able to disambiguate
the state for these two callbacks.
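The collision can be reproduced without Lightning. The sketch below assumes a key format that mirrors the checkpoint entries shown earlier (class name followed by the repr of the keyword arguments); make_state_key is a hypothetical stand-in for illustration, not Lightning's own implementation.

```python
def make_state_key(cls_name, **kwargs):
    # Hypothetical stand-in: class name followed by the repr of the keyword arguments
    return f"{cls_name}{kwargs!r}"


epoch_state = {"epochs": 2, "batches": 0}
batch_state = {"epochs": 0, "batches": 32}

# Keyed by class name only: the second entry overwrites the first
by_class_name = {}
by_class_name["Counter"] = epoch_state
by_class_name["Counter"] = batch_state

# Keyed by class name plus distinguishing arguments: both entries survive
by_state_key = {}
by_state_key[make_state_key("Counter", what="epochs")] = epoch_state
by_state_key[make_state_key("Counter", what="batches")] = batch_state

print(len(by_class_name))
print(sorted(by_state_key))
```

Only arguments that distinguish one instance's state from another belong in the key, which is why the Counter example leaves verbose out.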