Callback


A callback is a self-contained program that can be reused across projects.

Lightning has a callback system to execute callbacks when needed. Callbacks should capture NON-ESSENTIAL logic that is NOT required for your LightningModule to run.

Here’s the flow of how the callback hooks are executed:
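
One way to inspect the flow for a particular run is to record hooks as they fire. The following is a minimal sketch, not part of Lightning: `HookRecorder` and its `calls` attribute are hypothetical names, only a handful of hooks are overridden, and the `try`/`except` stand-in exists solely so the snippet can be executed without Lightning installed.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class HookRecorder(Callback):
    """Hypothetical helper: remembers the name of every hook it receives."""

    def __init__(self):
        super().__init__()
        self.calls = []  # hook names, in the order they were invoked

    def on_fit_start(self, trainer, pl_module):
        self.calls.append("on_fit_start")

    def on_train_start(self, trainer, pl_module):
        self.calls.append("on_train_start")

    def on_train_epoch_start(self, trainer, pl_module):
        self.calls.append("on_train_epoch_start")

    def on_train_epoch_end(self, trainer, pl_module):
        self.calls.append("on_train_epoch_end")

    def on_train_end(self, trainer, pl_module):
        self.calls.append("on_train_end")

    def on_fit_end(self, trainer, pl_module):
        self.calls.append("on_fit_end")
```

Passing an instance via `Trainer(callbacks=[recorder])` and reading `recorder.calls` after fitting should list these hooks in the order they ran.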

An overall Lightning system should have:

  1. Trainer for all engineering.

  2. LightningModule for all research code.

  3. Callbacks for non-essential code.


Example:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import Callback


class MyPrintingCallback(Callback):
    def on_init_start(self, trainer):
        print("Starting to init trainer!")

    def on_init_end(self, trainer):
        print("trainer is init now")

    def on_train_end(self, trainer, pl_module):
        print("do something when training ends")


trainer = Trainer(callbacks=[MyPrintingCallback()])
# Printed while the Trainer is being constructed:
#   Starting to init trainer!
#   trainer is init now

We successfully extended functionality without polluting the research code in our LightningModule.


Examples

You can do pretty much anything with callbacks.
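
As one illustration, the sketch below times each training epoch. It is a hedged example rather than a built-in: `EpochTimer` and `durations` are hypothetical names, and the stand-in base class is only there so the snippet runs without Lightning installed.

```python
import time

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class EpochTimer(Callback):
    """Hypothetical example: measures wall-clock time of each training epoch."""

    def __init__(self):
        super().__init__()
        self.durations = []  # seconds per finished epoch
        self._start = None

    def on_train_epoch_start(self, trainer, pl_module):
        self._start = time.perf_counter()

    def on_train_epoch_end(self, trainer, pl_module):
        # Only record a duration if the matching start hook was seen.
        if self._start is not None:
            self.durations.append(time.perf_counter() - self._start)
            self._start = None
```

The timing logic lives entirely in the callback, so the LightningModule needs no changes to gain or lose this behavior.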


Built-in Callbacks

Lightning has a few built-in callbacks.

Note

For a richer collection of callbacks, check out our bolts library.

BackboneFinetuning

Finetune a backbone model based on a user-defined learning rate schedule.

BaseFinetuning

This class implements the base logic for writing your own Finetuning Callback.

BasePredictionWriter

Base class to implement how the predictions should be stored.

Callback

Abstract base class used to build new callbacks.

DeviceStatsMonitor

Automatically monitors and logs device stats during the training stage.

EarlyStopping

Monitor a metric and stop training when it stops improving.

GPUStatsMonitor

Deprecated since version v1.5.

GradientAccumulationScheduler

Change the gradient accumulation factor according to a schedule.

LambdaCallback

Create a simple callback on the fly using lambda functions.

LearningRateMonitor

Automatically monitors and logs the learning rate for learning rate schedulers during training.

ModelCheckpoint

Save the model periodically by monitoring a quantity.

ModelPruning

Model pruning Callback, using PyTorch’s prune utilities.

ModelSummary

Generates a summary of all layers in a LightningModule.

ProgressBar

ProgressBarBase

The base class for progress bars in Lightning.

RichModelSummary

Generates a summary of all layers in a LightningModule with rich text formatting.

RichProgressBar

Create a progress bar with rich text formatting.

QuantizationAwareTraining

Quantization allows speeding up inference and decreasing memory requirements by performing computations and storing tensors at lower bitwidths (such as INT8 or FLOAT16) than floating point precision.

StochasticWeightAveraging

Implements the Stochastic Weight Averaging (SWA) Callback to average a model.

XLAStatsMonitor

Deprecated since version v1.5.


Persisting State

Some callbacks require internal state in order to function properly. You can optionally persist your callback’s state as part of model checkpoint files using the callback hooks on_save_checkpoint() and on_load_checkpoint(). Note that the returned state must be picklable.

When your callback is meant to be used only as a singleton, implementing the above two hooks is enough to persist state effectively. However, if passing multiple instances of the callback to the Trainer is supported, the callback must define a state_key property so that Lightning can distinguish the different states when loading the callback state. The following example illustrates this.

class Counter(Callback):
    def __init__(self, what="epochs", verbose=True):
        self.what = what
        self.verbose = verbose
        self.state = {"epochs": 0, "batches": 0}

    @property
    def state_key(self):
        # note: we do not include `verbose` here on purpose
        return self._generate_state_key(what=self.what)

    def on_train_epoch_end(self, *args, **kwargs):
        if self.what == "epochs":
            self.state["epochs"] += 1

    def on_train_batch_end(self, *args, **kwargs):
        if self.what == "batches":
            self.state["batches"] += 1

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.state.update(callback_state)

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return self.state.copy()


# two callbacks of the same type are being used
trainer = Trainer(callbacks=[Counter(what="epochs"), Counter(what="batches")])

A Lightning checkpoint from this Trainer with the two stateful callbacks will include the following information:

{
    "state_dict": ...,
    "callbacks": {
        "Counter{'what': 'batches'}": {"batches": 32, "epochs": 0},
        "Counter{'what': 'epochs'}": {"batches": 0, "epochs": 2},
        ...
    }
}

Implementing state_key is essential here. By default, state_key uses only the class name as the key (here, Counter), so without the override Lightning could not tell the states of these two callbacks apart.
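
The collision can be reproduced with plain dictionaries. This is a sketch under stated assumptions, not Lightning's implementation: `make_key` is a hypothetical helper that imitates the key format visible in the checkpoint excerpt above (class name followed by the repr of the keyword arguments).

```python
import pickle


def make_key(class_name, **kwargs):
    # Assumed format, copied from the checkpoint excerpt:
    # class name followed by the repr of the keyword arguments.
    return f"{class_name}{kwargs!r}"


epoch_state = {"batches": 0, "epochs": 2}
batch_state = {"batches": 32, "epochs": 0}

# Keys that include `what` keep both entries apart.
distinct = {
    make_key("Counter", what="epochs"): epoch_state,
    make_key("Counter", what="batches"): batch_state,
}

# With only the class name as the key, the second entry overwrites the first.
collided = {}
collided["Counter"] = epoch_state
collided["Counter"] = batch_state

# The stored state must survive pickling; plain dicts of ints do.
restored = pickle.loads(pickle.dumps(distinct))
```

Here `distinct` holds two entries while `collided` holds one, which is the ambiguity the state_key override avoids.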

Best Practices

The following are best practices when using/designing callbacks.

  1. Callbacks should be isolated in their functionality.

  2. Your callback should not rely on the behavior of other callbacks in order to work properly.

  3. Do not manually call methods from the callback.

  4. Directly calling methods (e.g., on_validation_end) is strongly discouraged.

  5. Whenever possible, your callbacks should not depend on the order in which they are executed.


Available Callback hooks

setup

Callback.setup(trainer, pl_module, stage=None)

Called when fit, validate, test, predict, or tune begins.

Return type

None

teardown

Callback.teardown(trainer, pl_module, stage=None)

Called when fit, validate, test, predict, or tune ends.

Return type

None
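
A paired setup/teardown is a natural place to acquire and release a resource. The sketch below is illustrative only: `FitLog` and `buffer` are hypothetical names, an in-memory buffer stands in for a real resource, and it assumes `stage` arrives as the string "fit" when fitting.

```python
import io

try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class FitLog(Callback):
    """Hypothetical example: keeps a text buffer open only while fitting."""

    def __init__(self):
        super().__init__()
        self.buffer = None

    def setup(self, trainer, pl_module, stage=None):
        # Assumption: `stage` is the string "fit" when fitting begins.
        if stage == "fit":
            self.buffer = io.StringIO()
            self.buffer.write("fit started\n")

    def teardown(self, trainer, pl_module, stage=None):
        # Release the resource only if this callback opened it.
        if stage == "fit" and self.buffer is not None:
            self.buffer.close()
            self.buffer = None
```

Checking `stage` in both hooks keeps the callback inert during the other entry points.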

on_init_start

Callback.on_init_start(trainer)

Called when the trainer initialization begins; the model has not yet been set.

Return type

None

on_init_end

Callback.on_init_end(trainer)

Called when the trainer initialization ends; the model has not yet been set.

Return type

None

on_fit_start

Callback.on_fit_start(trainer, pl_module)

Called when fit begins.

Return type

None

on_fit_end

Callback.on_fit_end(trainer, pl_module)

Called when fit ends.

Return type

None

on_sanity_check_start

Callback.on_sanity_check_start(trainer, pl_module)

Called when the validation sanity check starts.

Return type

None

on_sanity_check_end

Callback.on_sanity_check_end(trainer, pl_module)

Called when the validation sanity check ends.

Return type

None

on_train_batch_start

Callback.on_train_batch_start(trainer, pl_module, batch, batch_idx, unused=0)

Called when the train batch begins.

Return type

None

on_train_batch_end

Callback.on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx, unused=0)

Called when the train batch ends.

Return type

None

on_train_epoch_start

Callback.on_train_epoch_start(trainer, pl_module)

Called when the train epoch begins.

Return type

None

on_train_epoch_end

Callback.on_train_epoch_end(trainer, pl_module)

Called when the train epoch ends.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule and access outputs via the module OR

  2. Cache data across train batch hooks inside the callback implementation to post-process in this hook.

Return type

None
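
The second option, caching inside the callback, can be sketched as follows. `EpochLossAverager` and `epoch_means` are hypothetical names, and the code assumes that `outputs` is a dict with a "loss" entry convertible to `float`; the actual structure depends on what the LightningModule's training step returns.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class EpochLossAverager(Callback):
    """Hypothetical example: averages batch losses at the end of each epoch."""

    def __init__(self):
        super().__init__()
        self._batch_losses = []  # cache filled by the batch hook
        self.epoch_means = []    # one mean per finished epoch

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, unused=0):
        # Assumption: `outputs` is a dict with a "loss" entry convertible to float.
        self._batch_losses.append(float(outputs["loss"]))

    def on_train_epoch_end(self, trainer, pl_module):
        if self._batch_losses:
            mean = sum(self._batch_losses) / len(self._batch_losses)
            self.epoch_means.append(mean)
        # Reset the cache so the next epoch starts empty.
        self._batch_losses.clear()
```

Converting to `float` before caching avoids holding on to whatever object carried the loss for the whole epoch.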

on_validation_epoch_start

Callback.on_validation_epoch_start(trainer, pl_module)

Called when the validation epoch begins.

Return type

None

on_validation_epoch_end

Callback.on_validation_epoch_end(trainer, pl_module)

Called when the validation epoch ends.

Return type

None

on_test_epoch_start

Callback.on_test_epoch_start(trainer, pl_module)

Called when the test epoch begins.

Return type

None

on_test_epoch_end

Callback.on_test_epoch_end(trainer, pl_module)

Called when the test epoch ends.

Return type

None

on_epoch_start

Callback.on_epoch_start(trainer, pl_module)

Called when any of the train/val/test epochs begins.

Return type

None

on_epoch_end

Callback.on_epoch_end(trainer, pl_module)

Called when any of the train/val/test epochs ends.

Return type

None

on_batch_start

Callback.on_batch_start(trainer, pl_module)

Called when the training batch begins.

Return type

None

on_validation_batch_start

Callback.on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the validation batch begins.

Return type

None

on_validation_batch_end

Callback.on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the validation batch ends.

Return type

None

on_test_batch_start

Callback.on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx)

Called when the test batch begins.

Return type

None

on_test_batch_end

Callback.on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx)

Called when the test batch ends.

Return type

None

on_batch_end

Callback.on_batch_end(trainer, pl_module)

Called when the training batch ends.

Return type

None

on_train_start

Callback.on_train_start(trainer, pl_module)

Called when training begins.

Return type

None

on_train_end

Callback.on_train_end(trainer, pl_module)

Called when training ends.

Return type

None

on_pretrain_routine_start

Callback.on_pretrain_routine_start(trainer, pl_module)

Called when the pretrain routine begins.

Return type

None

on_pretrain_routine_end

Callback.on_pretrain_routine_end(trainer, pl_module)

Called when the pretrain routine ends.

Return type

None

on_validation_start

Callback.on_validation_start(trainer, pl_module)

Called when the validation loop begins.

Return type

None

on_validation_end

Callback.on_validation_end(trainer, pl_module)

Called when the validation loop ends.

Return type

None

on_test_start

Callback.on_test_start(trainer, pl_module)

Called when testing begins.

Return type

None

on_test_end

Callback.on_test_end(trainer, pl_module)

Called when testing ends.

Return type

None

on_keyboard_interrupt

Callback.on_keyboard_interrupt(trainer, pl_module)

Deprecated since version v1.5: This callback hook was deprecated in v1.5 in favor of on_exception and will be removed in v1.7.

Called when any trainer execution is interrupted by KeyboardInterrupt.

Return type

None

on_exception

Callback.on_exception(trainer, pl_module, exception)

Called when any trainer execution is interrupted by an exception.

Return type

None
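
Since on_keyboard_interrupt is deprecated in favor of this hook, interrupt handling can be folded into on_exception with a type check. A minimal sketch, with `InterruptNotifier`, `interrupted`, and `last_error` as hypothetical names:

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class InterruptNotifier(Callback):
    """Hypothetical example: tells a user interrupt apart from other failures."""

    def __init__(self):
        super().__init__()
        self.interrupted = False
        self.last_error = None

    def on_exception(self, trainer, pl_module, exception):
        if isinstance(exception, KeyboardInterrupt):
            # Covers what on_keyboard_interrupt used to be responsible for.
            self.interrupted = True
        else:
            self.last_error = repr(exception)
```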

on_save_checkpoint

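Callback.on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a model checkpoint; use it to persist state.

Parameters

  trainer: the current Trainer instance.

  pl_module: the current LightningModule instance.

  checkpoint: the checkpoint dictionary that will be saved.

Return type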

dict

Returns

The callback state.

on_load_checkpoint

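Callback.on_load_checkpoint(trainer, pl_module, callback_state)

Called when loading a model checkpoint; use it to reload state.

Parameters

  trainer: the current Trainer instance.

  pl_module: the current LightningModule instance.

  callback_state: the callback state returned by on_save_checkpoint.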

Note

The on_load_checkpoint won’t be called with an undefined state. If your on_load_checkpoint hook behavior doesn’t rely on a state, you will still need to override on_save_checkpoint to return a dummy state.

Return type

None
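
The note above can be sketched as a callback whose load hook ignores its argument but which still saves a placeholder. `ReloadNotice` and the "placeholder" key are hypothetical; the dict is kept non-empty on the assumption that an empty value could be treated the same as an undefined state.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class ReloadNotice(Callback):
    """Hypothetical example: only needs to know that a checkpoint was loaded."""

    def __init__(self):
        super().__init__()
        self.reloaded = False

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Placeholder state; non-empty on the assumption that an empty value
        # could be treated the same as "no state".
        return {"placeholder": True}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        # The state's content is irrelevant here; the call itself is the signal.
        self.reloaded = True
```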

on_before_backward

Callback.on_before_backward(trainer, pl_module, loss)

Called before loss.backward().

Return type

None

on_after_backward

Callback.on_after_backward(trainer, pl_module)

Called after loss.backward() and before optimizers are stepped.

Return type

None

on_before_optimizer_step

Callback.on_before_optimizer_step(trainer, pl_module, optimizer, opt_idx)

Called before optimizer.step().

Return type

None
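
Because this hook receives the optimizer, it can be used to inspect optimizer settings just before each step. The sketch below is an assumption-laden illustration: `LrSnapshot` and `history` are hypothetical names, and it assumes the object passed as `optimizer` exposes `param_groups` the way `torch.optim` optimizers do.

```python
try:
    from pytorch_lightning.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class LrSnapshot(Callback):
    """Hypothetical example: records learning rates before each optimizer step."""

    def __init__(self):
        super().__init__()
        self.history = []  # (opt_idx, [lr per parameter group]) tuples

    def on_before_optimizer_step(self, trainer, pl_module, optimizer, opt_idx):
        # Assumption: `optimizer.param_groups` is a list of dicts with an "lr" key.
        lrs = [group["lr"] for group in optimizer.param_groups]
        self.history.append((opt_idx, lrs))
```

For routine learning rate logging the built-in LearningRateMonitor is likely the better choice; this sketch only shows what the hook's arguments make possible.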

on_before_zero_grad

Callback.on_before_zero_grad(trainer, pl_module, optimizer)

Called before optimizer.zero_grad().

Return type

None