ModelHooks¶
- class lightning.pytorch.core.hooks.ModelHooks[source]¶
Bases: object
Hooks to be used in LightningModule.
- configure_model()[source]¶
Hook to create modules in a strategy- and precision-aware context.
This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we’d like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the
init_module()
context manager.
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent, i.e., after the first time the hook is called, subsequent calls to it should be a no-op.
- Return type: None
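A minimal sketch of an idempotent override; the layer sizes and attribute names are illustrative, not part of the API:

import torch
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self, vocab_size=32000, hidden=4096):
        super().__init__()
        self.save_hyperparameters()
        self.model = None  # created lazily in configure_model

    def configure_model(self):
        if self.model is not None:
            return  # the hook runs once per stage, so stay a no-op afterwards
        # layers built here are created under the strategy/precision context,
        # e.g. sharded right away when using FSDP
        self.model = torch.nn.Sequential(
            torch.nn.Embedding(self.hparams.vocab_size, self.hparams.hidden),
            torch.nn.Linear(self.hparams.hidden, self.hparams.vocab_size),
        )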
- configure_sharded_model()[source]¶
Deprecated.
Use
configure_model()
instead.
- Return type: None
- on_after_backward()[source]¶
Called after
loss.backward()
and before optimizers are stepped.
- Return type: None
Note
If using native AMP, the gradients will not be unscaled at this point. Use the
on_before_optimizer_step
hook if you need the unscaled gradients.
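A hedged sketch of inspecting the (possibly still-scaled) gradients here, assuming a TensorBoard logger whose experiment exposes add_histogram:

import lightning as L

class MyModel(L.LightningModule):
    def on_after_backward(self):
        # gradients exist now, but under native AMP they are still scaled
        if self.trainer.global_step % 100 == 0:  # log sparingly
            for name, param in self.named_parameters():
                if param.grad is not None:
                    self.logger.experiment.add_histogram(
                        f"grad/{name}", param.grad, self.trainer.global_step
                    )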
- on_before_optimizer_step(optimizer)[source]¶
Called before
optimizer.step()
.
If using gradient accumulation, the hook is called once the gradients have been accumulated. See:
accumulate_grad_batches
.
If using AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.
If clipping gradients, the gradients will not have been clipped yet.
Example:
def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
- on_before_zero_grad(optimizer)[source]¶
Called after
training_step()
and before
optimizer.zero_grad()
.
Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information with the weights updated.
This is where it is called:
for optimizer in optimizers:
    out = training_step(...)
    model.on_before_zero_grad(optimizer)  # <-- called here
    optimizer.zero_grad()
    backward()
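As a sketch, one way to use this hook is to track the freshly updated weight norm; the metric name and logging interval are arbitrary choices, not part of the API:

import lightning as L

class MyModel(L.LightningModule):
    def on_before_zero_grad(self, optimizer):
        # optimizer.step() has just run, so the parameters are up to date
        if self.trainer.global_step % 100 == 0:
            total_norm = sum(p.detach().norm() ** 2 for p in self.parameters()) ** 0.5
            self.log("weight_norm", total_norm)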
- on_fit_end()[source]¶
Called at the very end of fit.
If on DDP, it is called on every process.
- Return type: None
- on_fit_start()[source]¶
Called at the very beginning of fit.
If on DDP, it is called on every process.
- Return type: None
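Because these hooks run on every process under DDP, work that should happen only once can be guarded explicitly; a sketch:

import lightning as L

class MyModel(L.LightningModule):
    def on_fit_start(self):
        # runs on every DDP process; guard one-time work on the main process
        if self.trainer.is_global_zero:
            print("fit is starting")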
- on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
Called in the predict loop after the batch.
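For example, a sketch that caches each batch's predictions on the module; the attribute name is arbitrary:

import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.predict_outputs = []

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # `outputs` is whatever predict_step returned for this batch
        self.predict_outputs.append(outputs)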
- on_predict_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
Called in the predict loop before anything happens for that batch.
- on_predict_model_eval()[source]¶
Called when the predict loop starts.
The predict loop by default calls
.eval()
on the LightningModule before it starts. Override this hook to change the behavior.
- Return type: None
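One possible override, sketched here, keeps dropout layers active during prediction (e.g. for Monte Carlo dropout); this is an illustration of changing the behavior, not the default:

import torch
import lightning as L

class MyModel(L.LightningModule):
    def on_predict_model_eval(self):
        self.eval()  # start from the default behavior
        for module in self.modules():
            if isinstance(module, torch.nn.Dropout):
                module.train()  # keep dropout sampling at predict time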
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
Called in the test loop after the batch.
- on_test_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
Called in the test loop before anything happens for that batch.
- on_test_epoch_start()[source]¶
Called in the test loop at the very beginning of the epoch.
- Return type: None
- on_test_model_eval()[source]¶
Called when the test loop starts.
The test loop by default calls
.eval()
on the LightningModule before it starts. Override this hook to change the behavior. See also
on_test_model_train()
.
- Return type: None
- on_test_model_train()[source]¶
Called when the test loop ends.
The test loop by default restores the training mode of the LightningModule to what it was before starting testing. Override this hook to change the behavior. See also
on_test_model_eval()
.
- Return type: None
- on_train_batch_end(outputs, batch, batch_idx)[source]¶
Called in the training loop after the batch.
- Parameters:
  - outputs – the output of training_step
  - batch – the batched data as returned by the training dataloader
  - batch_idx – the index of this batch
- Return type: None
Note
The value outputs["loss"] here will be the normalized value w.r.t. accumulate_grad_batches of the loss returned from training_step.
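A small sketch that logs this normalized loss; the metric name is made up, and the dict check guards against training_step returning something other than a dict:

import lightning as L

class MyModel(L.LightningModule):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # with accumulate_grad_batches > 1, outputs["loss"] is already divided
        # by the accumulation factor
        if isinstance(outputs, dict) and "loss" in outputs:
            self.log("train/normalized_loss", outputs["loss"])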
- on_train_batch_start(batch, batch_idx)[source]¶
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
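For illustration, a sketch that cuts the epoch short after a fixed number of batches; the threshold is arbitrary:

import lightning as L

class MyModel(L.LightningModule):
    def on_train_batch_start(self, batch, batch_idx):
        # returning -1 skips training for the rest of the current epoch
        if batch_idx >= 100:
            return -1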
- on_train_end()[source]¶
Called at the end of training before logger experiment is closed.
- Return type: None
- on_train_epoch_end()[source]¶
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the
LightningModule
and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

- Return type: None
- on_train_epoch_start()[source]¶
Called in the training loop at the very beginning of the epoch.
- Return type: None
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]¶
Called in the validation loop after the batch.
- on_validation_batch_start(batch, batch_idx, dataloader_idx=0)[source]¶
Called in the validation loop before anything happens for that batch.
- on_validation_epoch_end()[source]¶
Called in the validation loop at the very end of the epoch.
- Return type: None
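Validation step outputs can be cached the same way as in the training example above; a sketch with the loss computation elided:

import torch
import lightning as L

class MyModel(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.validation_step_outputs = []

    def validation_step(self, batch, batch_idx):
        loss = ...  # compute the validation loss for this batch
        self.validation_step_outputs.append(loss)
        return loss

    def on_validation_epoch_end(self):
        epoch_mean = torch.stack(self.validation_step_outputs).mean()
        self.log("val_epoch_mean", epoch_mean)
        self.validation_step_outputs.clear()  # free up the memory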
- on_validation_epoch_start()[source]¶
Called in the validation loop at the very beginning of the epoch.
- Return type: None
- on_validation_model_eval()[source]¶
Called when the validation loop starts.
The validation loop by default calls
.eval()
on the LightningModule before it starts. Override this hook to change the behavior. See also
on_validation_model_train()
.
- Return type: None
- on_validation_model_train()[source]¶
Called when the validation loop ends.
The validation loop by default restores the training mode of the LightningModule to what it was before starting validation. Override this hook to change the behavior. See also
on_validation_model_eval()
.
- Return type: None
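As a sketch of changing the default, the overrides below keep a hypothetical self.backbone submodule in train mode during validation (so, for example, its BatchNorm statistics keep updating) and restore training mode afterwards:

import lightning as L

class MyModel(L.LightningModule):
    def on_validation_model_eval(self):
        self.eval()
        self.backbone.train()  # hypothetical submodule left in train mode

    def on_validation_model_train(self):
        self.train()  # put the whole module back into training mode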