Callback
- class lightning.pytorch.callbacks.Callback
Bases: object
Abstract base class used to build new callbacks.
Subclass this class and override any of the relevant hooks.
- load_state_dict(state_dict)
Called when loading a checkpoint. Implement this hook to reload the callback’s state from the given state_dict.
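As a rough illustration of how load_state_dict pairs with its counterpart state_dict, the sketch below round-trips a counter between two callback instances without a Trainer. The BatchCounter class and its batches_seen attribute are hypothetical names, and the fallback Callback class exists only so the snippet runs where Lightning is not installed.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class BatchCounter(Callback):
    """Hypothetical callback that counts the training batches it has seen."""

    def __init__(self):
        self.batches_seen = 0

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self.batches_seen += 1

    def state_dict(self):
        # The returned dict is what gets stored for this callback.
        return {"batches_seen": self.batches_seen}

    def load_state_dict(self, state_dict):
        # Restore the counter from a dict produced by state_dict().
        self.batches_seen = state_dict["batches_seen"]


# Round trip by hand: save from one instance, restore into a fresh one.
original = BatchCounter()
for idx in range(3):
    original.on_train_batch_end(None, None, None, None, idx)

restored = BatchCounter()
restored.load_state_dict(original.state_dict())
print(restored.batches_seen)
```

In a real run the Trainer would call these hooks itself while saving and loading checkpoints; the manual calls here only mimic that flow.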
- on_after_backward(trainer, pl_module)
Called after loss.backward() and before optimizers are stepped.
- on_before_optimizer_step(trainer, pl_module, optimizer)
Called before optimizer.step().
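Because on_before_optimizer_step receives the optimizer itself, one possible use is to record the learning rate in effect for each step. The following is a minimal sketch under that assumption: LearningRateRecorder is a hypothetical name, the optimizer is replaced by a simple stand-in object exposing param_groups the way torch optimizers do, and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
from types import SimpleNamespace

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class LearningRateRecorder(Callback):
    """Hypothetical callback that records learning rates before each step."""

    def __init__(self):
        self.history = []

    def on_before_optimizer_step(self, trainer, pl_module, optimizer):
        # torch optimizers keep their hyperparameters in param_groups,
        # one dict per parameter group.
        self.history.append([group["lr"] for group in optimizer.param_groups])


# Stand-in for a real optimizer, used only to exercise the hook by hand.
fake_optimizer = SimpleNamespace(param_groups=[{"lr": 0.1}, {"lr": 0.01}])
recorder = LearningRateRecorder()
recorder.on_before_optimizer_step(None, None, fake_optimizer)
print(recorder.history)
```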
- on_before_zero_grad(trainer, pl_module, optimizer)
Called before optimizer.zero_grad().
- on_exception(trainer, pl_module, exception)
Called when any trainer execution is interrupted by an exception.
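A small sketch of how on_exception might be used to keep a reference to the failure for later inspection. ExceptionRecorder and last_exception are hypothetical names, the exception is raised by hand rather than by a Trainer, and the fallback Callback class is only there so the snippet runs without Lightning.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class ExceptionRecorder(Callback):
    """Hypothetical callback that remembers the exception that stopped a run."""

    def __init__(self):
        self.last_exception = None

    def on_exception(self, trainer, pl_module, exception):
        self.last_exception = exception


# Simulate an interrupted run by invoking the hook manually.
recorder = ExceptionRecorder()
try:
    raise RuntimeError("simulated failure")
except RuntimeError as err:
    recorder.on_exception(None, None, err)

print(repr(recorder.last_exception))
```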
- on_load_checkpoint(trainer, pl_module, checkpoint)
Called when loading a model checkpoint. Use this hook to reload state.
- Parameters
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the predict batch ends.
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the predict batch begins.
- on_predict_epoch_start(trainer, pl_module)
Called when the predict epoch begins.
- on_sanity_check_end(trainer, pl_module)
Called when the validation sanity check ends.
- on_sanity_check_start(trainer, pl_module)
Called when the validation sanity check starts.
- on_save_checkpoint(trainer, pl_module, checkpoint)
Called when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters
pl_module (LightningModule) – the current LightningModule instance.
checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
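The two checkpoint hooks, on_save_checkpoint and on_load_checkpoint, both receive the checkpoint as a plain dictionary, so extra entries can be written on save and read back on load. The sketch below assumes a hypothetical RunNotes callback and a hypothetical "run_notes" key; a plain dict stands in for the checkpoint, and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class RunNotes(Callback):
    """Hypothetical callback that stores a free-form note in the checkpoint."""

    KEY = "run_notes"  # hypothetical key, chosen for this illustration only

    def __init__(self, notes=""):
        self.notes = notes

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Add an extra entry to the dictionary that will be saved.
        checkpoint[self.KEY] = self.notes

    def on_load_checkpoint(self, trainer, pl_module, checkpoint):
        # Read the entry back; fall back to an empty note if it is absent.
        self.notes = checkpoint.get(self.KEY, "")


# A plain dict stands in for the checkpoint the Trainer would pass.
checkpoint = {"epoch": 3}
RunNotes("baseline run").on_save_checkpoint(None, None, checkpoint)

loaded = RunNotes()
loaded.on_load_checkpoint(None, None, checkpoint)
print(loaded.notes)
```

Using a key that is unlikely to collide with entries written by the Trainer is a deliberate choice here; the exact key name is arbitrary.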
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the test batch ends.
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the test batch begins.
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
Called when the train batch ends.
Note
The value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
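To make the note above concrete, here is a sketch that recovers the un-normalized loss inside on_train_batch_end. It rests on two assumptions: that "normalized" means the loss was divided by accumulate_grad_batches, and that outputs is a dict with a "loss" entry. RawLossTracker is a hypothetical name, a simple namespace stands in for the Trainer, and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
from types import SimpleNamespace

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class RawLossTracker(Callback):
    """Hypothetical callback that undoes the accumulation scaling of the loss."""

    def __init__(self):
        self.raw_losses = []

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # Assumption: outputs["loss"] == training_step loss / accumulate_grad_batches,
        # so multiplying by the same factor recovers the original value.
        scaled = float(outputs["loss"])  # also works for a single-element tensor
        self.raw_losses.append(scaled * trainer.accumulate_grad_batches)


# Stand-in trainer with 4 accumulated batches; training_step returned 1.0.
fake_trainer = SimpleNamespace(accumulate_grad_batches=4)
tracker = RawLossTracker()
tracker.on_train_batch_end(fake_trainer, None, {"loss": 0.25}, None, 0)
print(tracker.raw_losses)
```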
- on_train_batch_start(trainer, pl_module, batch, batch_idx)
Called when the train batch begins.
- on_train_epoch_end(trainer, pl_module)
Called when the train epoch ends.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.LightningModule and access them in this hook:

    import lightning as L
    import torch


    class MyLightningModule(L.LightningModule):
        def __init__(self):
            super().__init__()
            self.training_step_outputs = []

        def training_step(self, batch, batch_idx):
            loss = ...  # compute the loss for this batch
            self.training_step_outputs.append(loss)
            return loss


    class MyCallback(L.Callback):
        def on_train_epoch_end(self, trainer, pl_module):
            # do something with all training_step outputs, for example:
            epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
            pl_module.log("training_epoch_mean", epoch_mean)
            # free up the memory
            pl_module.training_step_outputs.clear()
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
Called when the validation batch ends.
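The batch-level hooks for validation, test, and predict take a dataloader_idx argument that defaults to 0. As a sketch of how it could be used, the hypothetical BatchesPerLoader callback below counts validation batches per dataloader; the hook is called by hand, and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
from collections import Counter

try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class BatchesPerLoader(Callback):
    """Hypothetical callback that counts validation batches per dataloader."""

    def __init__(self):
        self.counts = Counter()

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self.counts[dataloader_idx] += 1


# Simulate two batches from the first dataloader and one from a second.
counter = BatchesPerLoader()
counter.on_validation_batch_end(None, None, None, None, 0)
counter.on_validation_batch_end(None, None, None, None, 1)
counter.on_validation_batch_end(None, None, None, None, 0, dataloader_idx=1)
print(dict(counter.counts))
```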
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
Called when the validation batch begins.
- on_validation_epoch_start(trainer, pl_module)
Called when the validation epoch begins.
- on_validation_start(trainer, pl_module)
Called when the validation loop begins.
- setup(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune begins.
- teardown(trainer, pl_module, stage)
Called when fit, validate, test, predict, or tune ends.
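Both setup and teardown receive a stage argument naming the run that is beginning or ending. The sketch below assumes stage arrives as a string such as "fit" or "test" and simply logs the events in order; StageLogger is a hypothetical name, the hooks are called by hand, and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class StageLogger(Callback):
    """Hypothetical callback that records which stages were set up and torn down."""

    def __init__(self):
        self.events = []

    def setup(self, trainer, pl_module, stage):
        self.events.append(f"setup:{stage}")

    def teardown(self, trainer, pl_module, stage):
        self.events.append(f"teardown:{stage}")


# Mimic a fit run followed by a test run.
logger = StageLogger()
for stage in ("fit", "test"):
    logger.setup(None, None, stage)
    logger.teardown(None, None, stage)
print(logger.events)
```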
- property state_key: str
Identifier for the state of the callback.
Used to store and retrieve a callback’s state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. Implementations of a callback need to provide a unique state key if 1) the callback has state and 2) it is desired to maintain the state of multiple instances of that callback.
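The sketch below illustrates why a unique state key matters when two instances of the same stateful callback coexist. NamedCounter and its what argument are hypothetical, and the key is built by hand with an f-string purely for illustration; a plain dict stands in for checkpoint["callbacks"], and the fallback Callback class only keeps the snippet runnable without Lightning.

```python
try:
    from lightning.pytorch.callbacks import Callback
except ImportError:
    class Callback:  # stand-in so the sketch runs without Lightning installed
        pass


class NamedCounter(Callback):
    """Hypothetical stateful callback that may be instantiated several times."""

    def __init__(self, what):
        self.what = what
        self.count = 0

    @property
    def state_key(self):
        # Include the distinguishing argument so each instance gets its own key.
        return f"{type(self).__name__}[what={self.what}]"

    def state_dict(self):
        return {"count": self.count}


epochs = NamedCounter("epochs")
batches = NamedCounter("batches")
epochs.count, batches.count = 2, 50

# Stand-in for checkpoint["callbacks"]: one entry per distinct state key.
callbacks_state = {cb.state_key: cb.state_dict() for cb in (epochs, batches)}
print(sorted(callbacks_state))
```

If both instances shared a single key, the second entry would overwrite the first in the dictionary, which is the situation a unique state key avoids.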