Callback
- class lightning.pytorch.callbacks.Callback
  Bases: object
  Abstract base class used to build new callbacks. Subclass this class and override any of the relevant hooks.

- load_state_dict(state_dict)
  Called when loading a checkpoint; implement it to reload the callback's state given the callback's state_dict.
- on_after_backward(trainer, pl_module)
  Called after loss.backward() and before optimizers are stepped.
  Return type: None
 
- on_before_optimizer_step(trainer, pl_module, optimizer)
  Called before optimizer.step().
  Return type: None
 
- on_before_zero_grad(trainer, pl_module, optimizer)
  Called before optimizer.zero_grad().
  Return type: None
 
- on_exception(trainer, pl_module, exception)
  Called when any trainer execution is interrupted by an exception.
  Return type: None
 
- on_load_checkpoint(trainer, pl_module, checkpoint)
  Called when loading a model checkpoint; use it to reload state.
  Parameters:
    - trainer (Trainer) – the current Trainer instance.
    - pl_module (LightningModule) – the current LightningModule instance.
    - checkpoint (Dict[str, Any]) – the full checkpoint dictionary that got loaded by the Trainer.
  Return type: None
 
- on_predict_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
  Called when the predict batch ends.
  Return type: None
 
- on_predict_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
  Called when the predict batch begins.
  Return type: None
 
- on_predict_epoch_start(trainer, pl_module)
  Called when the predict epoch begins.
  Return type: None
 
- on_sanity_check_end(trainer, pl_module)
  Called when the validation sanity check ends.
  Return type: None
 
- on_sanity_check_start(trainer, pl_module)
  Called when the validation sanity check starts.
  Return type: None
 
- on_save_checkpoint(trainer, pl_module, checkpoint)
  Called when saving a checkpoint, giving you a chance to store anything else you might want to save.
  Parameters:
    - trainer (Trainer) – the current Trainer instance.
    - pl_module (LightningModule) – the current LightningModule instance.
    - checkpoint (Dict[str, Any]) – the checkpoint dictionary that will be saved.
  Return type: None
 
- on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
  Called when the test batch ends.
  Return type: None
 
- on_test_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
  Called when the test batch begins.
  Return type: None
 
- on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
  Called when the train batch ends.
  Note: the value outputs["loss"] here is the loss returned from training_step, normalized with respect to accumulate_grad_batches.
  Return type: None
 
- on_train_batch_start(trainer, pl_module, batch, batch_idx)
  Called when the train batch begins.
  Return type: None
 
- on_train_epoch_end(trainer, pl_module)
  Called when the train epoch ends.
  To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the pytorch_lightning.LightningModule and access them in this hook:

```python
import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        # buffer filled by training_step and emptied by the callback below
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss for this batch
        self.training_step_outputs.append(loss)
        return loss


class MyCallback(L.Callback):
    def on_train_epoch_end(self, trainer, pl_module):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(pl_module.training_step_outputs).mean()
        pl_module.log("training_epoch_mean", epoch_mean)
        # free up the memory
        pl_module.training_step_outputs.clear()
```

  Return type: None
 
- on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)
  Called when the validation batch ends.
  Return type: None
 
- on_validation_batch_start(trainer, pl_module, batch, batch_idx, dataloader_idx=0)
  Called when the validation batch begins.
  Return type: None
 
- on_validation_epoch_start(trainer, pl_module)
  Called when the validation epoch begins.
  Return type: None
 
- on_validation_start(trainer, pl_module)
  Called when the validation loop begins.
  Return type: None
 
- setup(trainer, pl_module, stage)
  Called when fit, validate, test, predict, or tune begins.
  Return type: None
 
- teardown(trainer, pl_module, stage)
  Called when fit, validate, test, predict, or tune ends.
  Return type: None
 
- property state_key: str
  Identifier for the state of the callback.
  Used to store and retrieve a callback's state from the checkpoint dictionary by checkpoint["callbacks"][state_key]. A callback implementation must provide a unique state key if (1) the callback has state and (2) the state of multiple instances of that callback needs to be maintained.
  Return type: str