Callback¶
A callback is a self-contained program that can be reused across projects.
Lightning has a callback system to execute them when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your lightning module to run.
An overall Lightning system should have:
Trainer for all engineering.
LightningModule for all research code.
Callbacks for non-essential code.
Example:
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_train_start(self, trainer, pl_module):
        print("Training is starting")

    def on_train_end(self, trainer, pl_module):
        print("Training is ending")


trainer = Trainer(callbacks=[MyPrintingCallback()])
We successfully extended functionality without polluting our super clean LightningModule research code.
Examples¶
You can do pretty much anything with callbacks.
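For instance, a callback can combine several hooks to time a training epoch end to end. A minimal sketch (the EpochTimer name and timing logic are illustrative, not a built-in; see the Timer callback below for the official equivalent):

import time

from pytorch_lightning.callbacks import Callback


class EpochTimer(Callback):
    def on_train_epoch_start(self, trainer, pl_module):
        # record the wall-clock time when the epoch begins
        self.start = time.monotonic()

    def on_train_epoch_end(self, trainer, pl_module):
        elapsed = time.monotonic() - self.start
        print(f"epoch {trainer.current_epoch} took {elapsed:.1f}s")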
Built-in Callbacks¶
Lightning has a few built-in callbacks.
Note
For a richer collection of callbacks, check out our bolts library.
BackboneFinetuning: Finetune a backbone model based on a user-defined learning rate scheduling.
BaseFinetuning: This class implements the base logic for writing your own finetuning callback.
BasePredictionWriter: Base class to implement how the predictions should be stored.
Callback: Abstract base class used to build new callbacks.
DeviceStatsMonitor: Automatically monitors and logs device stats during the training stage.
EarlyStopping: Monitor a metric and stop training when it stops improving.
GPUStatsMonitor: Deprecated since version v1.5.
GradientAccumulationScheduler: Change gradient accumulation factor according to scheduling.
LambdaCallback: Create a simple callback on the fly using lambda functions.
LearningRateMonitor: Automatically monitor and log the learning rate of learning rate schedulers during training.
ModelCheckpoint: Save the model periodically by monitoring a quantity.
ModelPruning: Model pruning callback, using PyTorch's prune utilities.
ModelSummary: Generates a summary of all layers in a LightningModule.
ProgressBarBase: The base class for progress bars in Lightning.
QuantizationAwareTraining: Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.
RichModelSummary: Generates a summary of all layers in a LightningModule with rich text formatting.
RichProgressBar: Create a progress bar with rich text formatting.
StochasticWeightAveraging: Implements the Stochastic Weight Averaging (SWA) callback to average a model.
Timer: The Timer callback tracks the time spent in the training, validation, and test loops and interrupts the Trainer if the given time limit for the training loop is reached.
TQDMProgressBar: This is the default progress bar used by Lightning.
XLAStatsMonitor: Deprecated since version v1.5.
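Built-in callbacks are passed to the Trainer the same way as custom ones. For example, combining EarlyStopping and ModelCheckpoint (the monitored metric name "val_loss" is a placeholder for whatever your LightningModule logs):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint

trainer = Trainer(
    callbacks=[
        EarlyStopping(monitor="val_loss", mode="min"),
        ModelCheckpoint(monitor="val_loss", save_top_k=1),
    ]
)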
Persisting State¶
Some callbacks require internal state in order to function properly. You can optionally
choose to persist your callback’s state as part of model checkpoint files using
state_dict()
and load_state_dict()
.
Note that the returned state must be able to be pickled.
When your callback is meant to be used only as a singleton callback then implementing the above two hooks is enough
to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, then
the callback must define a state_key
property in order for Lightning
to be able to distinguish the different states when loading the callback state. This concept is best illustrated by
the following example.
class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self):
        # note: we do not include `verbose` here on purpose
        return self._generate_state_key(what=self.what)

    def on_train_epoch_end(self, *args, **kwargs):
        if self.what == "epochs":
            self.state["epochs"] += 1

    def on_train_batch_end(self, *args, **kwargs):
        if self.what == "batches":
            self.state["batches"] += 1

    def load_state_dict(self, state_dict):
        self.state.update(state_dict)

    def state_dict(self):
        return self.state.copy()


# two callbacks of the same type are being used
trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])
A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:
{
    "state_dict": ...,
    "callbacks": {
        "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
        "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
        ...
    }
}
The implementation of a state_key is essential here. If it were missing, Lightning would not be able to disambiguate the states of these two callbacks, because by default the state_key is just the class name, here Counter.
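To see this in practice, you can load the checkpoint file directly and inspect its "callbacks" entry. A minimal sketch, assuming a hypothetical checkpoint path:

import torch

ckpt = torch.load("example.ckpt")  # hypothetical path saved by the Trainer above

# each Counter instance is stored under its own state_key
print(ckpt["callbacks"]["Counter{'what': 'epochs'}"])   # e.g. {"batches": 0, "epochs": 2}
print(ckpt["callbacks"]["Counter{'what': 'batches'}"])  # e.g. {"batches": 32, "epochs": 0}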
Best Practices¶
The following are best practices when using/designing callbacks.
Callbacks should be isolated in their functionality.
Your callback should not rely on the behavior of other callbacks in order to work properly.
Do not manually call methods from the callback.
Directly calling methods (e.g., on_validation_end) is strongly discouraged.
Whenever possible, your callbacks should not depend on the order in which they are executed.
Callback API¶
Here is the full API of methods available in the Callback base class.
The Callback class is the base for all the callbacks in Lightning, just like the LightningModule is the base for all models.
It defines a public interface that each callback implementation must follow; the key ones are:
Properties¶
state_key¶
- Callback.state_key
Identifier for the state of the callback.
Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
- Return type: str
Hooks¶
on_configure_sharded_model¶
setup¶
teardown¶
on_init_start¶
on_init_end¶
on_fit_start¶
on_fit_end¶
on_sanity_check_start¶
on_sanity_check_end¶
on_train_batch_start¶
on_train_batch_end¶
on_train_epoch_start¶
on_train_epoch_end¶
- Callback.on_train_epoch_end(trainer, pl_module)[source]
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, either:
Implement training_epoch_end in the LightningModule and access outputs via the module, OR
Cache data across train batch hooks inside the callback implementation and post-process it in this hook (see the sketch below).
- Return type: None
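A minimal sketch of the caching approach. The AverageLossCallback name and the "loss" key are hypothetical; outputs is whatever your training_step returns:

from pytorch_lightning.callbacks import Callback


class AverageLossCallback(Callback):
    def __init__(self):
        self.losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # `outputs` is what training_step returned, e.g. {"loss": ...}
        self.losses.append(outputs["loss"].detach())

    def on_train_epoch_end(self, trainer, pl_module):
        if self.losses:
            avg = sum(self.losses) / len(self.losses)
            print(f"average training loss: {avg:.4f}")
        self.losses = []  # reset the cache for the next epoch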
on_validation_epoch_start¶
on_validation_epoch_end¶
on_test_epoch_start¶
on_test_epoch_end¶
on_predict_epoch_start¶
on_predict_epoch_end¶
- Callback.on_predict_epoch_end(trainer, pl_module, outputs)[source]
Called when the predict epoch ends.
- Return type: None
on_validation_batch_start¶
on_validation_batch_end¶
on_test_batch_start¶
on_test_batch_end¶
on_predict_batch_start¶
on_predict_batch_end¶
on_train_start¶
on_train_end¶
on_validation_start¶
on_validation_end¶
on_test_start¶
on_test_end¶
on_predict_start¶
on_predict_end¶
on_keyboard_interrupt¶
on_exception¶
state_dict¶
on_save_checkpoint¶
- Callback.on_save_checkpoint(trainer, pl_module, checkpoint)[source]
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
trainer (Trainer): the current Trainer instance.
pl_module (LightningModule): the current LightningModule instance.
checkpoint (Dict[str, Any]): the checkpoint dictionary that will be saved.
- Return type: Optional[dict]
- Returns
None or the callback state. Support for returning callback state will be removed in v1.8.
Deprecated since version v1.6: Returning a value from this method was deprecated in v1.6 and will be removed in v1.8. Implement Callback.state_dict instead to return state. In v1.8, Callback.on_save_checkpoint can only return None.
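Because the checkpoint dictionary is passed in directly, a callback can also write extra entries into it instead of returning state, which sidesteps the deprecated return-value mechanism. A minimal sketch (the "my_extra_state" key is a hypothetical name):

from pytorch_lightning.callbacks import Callback


class ExtraStateCallback(Callback):
    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # write directly into the checkpoint dict rather than returning state
        checkpoint["my_extra_state"] = {"finished_epochs": trainer.current_epoch}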
load_state_dict¶
on_load_checkpoint¶
- Callback.on_load_checkpoint(trainer, pl_module, callback_state)[source]
Called when loading a model checkpoint. Use this hook to reload state.
- Parameters
trainer (Trainer): the current Trainer instance.
pl_module (LightningModule): the current LightningModule instance.
callback_state (Dict[str, Any]): the callback state returned by on_save_checkpoint.
Note
The on_load_checkpoint hook won't be called with an undefined state. If your on_load_checkpoint hook behavior doesn't rely on a state, you will still need to override on_save_checkpoint to return a dummy state.
Deprecated since version v1.6: This callback hook will change its signature and behavior in v1.8. If you wish to load the state of the callback, use Callback.load_state_dict instead. In v1.8, Callback.on_load_checkpoint(checkpoint) will receive the entire loaded checkpoint dictionary instead of only the callback state from the checkpoint.
- Return type: None
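Given these deprecations, new stateful callbacks should implement state_dict and load_state_dict directly instead of the checkpoint hooks. A minimal sketch of that pattern (the RunningCount name is hypothetical):

from pytorch_lightning.callbacks import Callback


class RunningCount(Callback):
    def __init__(self):
        self.count = 0

    def on_train_batch_end(self, *args, **kwargs):
        self.count += 1

    def state_dict(self):
        # must be picklable; stored under checkpoint["callbacks"][state_key]
        return {"count": self.count}

    def load_state_dict(self, state_dict):
        self.count = state_dict["count"]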