ModelHooks

class lightning.pytorch.core.hooks.ModelHooks[source]

Bases: object

Hooks to be used in LightningModule.

configure_model()[source]

Hook to create modules in a strategy- and precision-aware context.

This is particularly useful when using sharded strategies (FSDP and DeepSpeed), where we’d like to shard the model instantly to save memory and initialization time. For non-sharded strategies, you can choose to override this hook or to initialize your model under the init_module() context manager.

This hook is called during each of the fit/validate/test/predict stages in the same process, so ensure that the implementation is idempotent: after the first call, subsequent calls should be no-ops.

Return type:

None
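
Example: a minimal sketch of an idempotent implementation, assuming self.model is set to None in __init__ (the layer sizes are illustrative):

def configure_model(self):
    if self.model is not None:  # already configured; make repeated calls a no-op
        return
    self.model = torch.nn.Sequential(
        torch.nn.Linear(32, 32),
        torch.nn.ReLU(),
        torch.nn.Linear(32, 2),
    )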

configure_sharded_model()[source]

Deprecated.

Use configure_model() instead.

Return type:

None

on_after_backward()[source]

Called after loss.backward() and before optimizers are stepped.

Return type:

None

Note

If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step() if you need the unscaled gradients.
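
Example: a hedged sketch that logs the global gradient norm (note the values may still be AMP-scaled at this point):

def on_after_backward(self):
    # collect per-parameter gradient norms; frozen parameters have no gradient
    grads = [p.grad.norm() for p in self.parameters() if p.grad is not None]
    if grads:
        self.log("grad_norm", torch.norm(torch.stack(grads)))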

on_before_backward(loss)[source]

Called before loss.backward().

Parameters:

loss (Tensor) – The loss divided by the number of batches for gradient accumulation, and scaled if using AMP.

Return type:

None
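
Example: a minimal sketch that records the loss exactly as it will be passed to backward():

def on_before_backward(self, loss):
    # the loss is already averaged for gradient accumulation and scaled if using AMP
    self.log("loss_before_backward", loss.detach())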

on_before_optimizer_step(optimizer)[source]

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If using AMP, the loss will be unscaled before calling this hook. See the automatic mixed precision documentation for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:

optimizer (Optimizer) – Current optimizer being used.

Return type:

None

Example:

def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            if v.grad is not None:  # frozen parameters have no gradient
                self.logger.experiment.add_histogram(
                    tag=k, values=v.grad, global_step=self.trainer.global_step
                )

on_before_zero_grad(optimizer)[source]

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information, since the weights have just been updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()

    backward()

Parameters:

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type:

None
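
Example: a hedged sketch that inspects the freshly updated weights, assuming a TensorBoard logger:

def on_before_zero_grad(self, optimizer):
    # the weights were just updated by optimizer.step()
    if self.trainer.global_step % 100 == 0:  # throttle logging
        for name, param in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=f"weights/{name}", values=param.detach(), global_step=self.trainer.global_step
            )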

on_fit_end()[source]

Called at the very end of fit.

If using DDP, it is called on every process.

Return type:

None

on_fit_start()[source]

Called at the very beginning of fit.

If using DDP, it is called on every process.

Return type:

None
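
Example: a minimal sketch that restricts side effects to the main process:

def on_fit_start(self):
    # this hook runs on every DDP process; guard rank-zero-only work
    if self.trainer.is_global_zero:
        print("fit is starting")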

on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the predict loop after the batch.

Parameters:
  • outputs (Optional[Any]) – The outputs of predict_step(x)

  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_predict_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the predict loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_predict_end()[source]

Called at the end of predicting.

Return type:

None

on_predict_epoch_end()[source]

Called at the end of the predict epoch.

Return type:

None

on_predict_epoch_start()[source]

Called at the beginning of the predict epoch.

Return type:

None

on_predict_model_eval()[source]

Called when the predict loop starts.

The predict loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior.

Return type:

None
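
Example: a hedged sketch that keeps dropout active during prediction (e.g., for Monte Carlo dropout) while the rest of the module runs in eval mode:

def on_predict_model_eval(self):
    self.eval()  # the default behavior described above
    for m in self.modules():
        if isinstance(m, torch.nn.Dropout):
            m.train()  # re-enable dropout at predict time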

on_predict_start()[source]

Called at the beginning of predicting.

Return type:

None

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the test loop after the batch.

Parameters:
  • outputs (Union[Tensor, Mapping[str, Any], None]) – The outputs of test_step(x)

  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_test_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the test loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_test_end()[source]

Called at the end of testing.

Return type:

None

on_test_epoch_end()[source]

Called in the test loop at the very end of the epoch.

Return type:

None

on_test_epoch_start()[source]

Called in the test loop at the very beginning of the epoch.

Return type:

None

on_test_model_eval()[source]

Called when the test loop starts.

The test loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_test_model_train().

Return type:

None

on_test_model_train()[source]

Called when the test loop ends.

The test loop by default restores the training mode of the LightningModule to what it was before starting testing. Override this hook to change the behavior. See also on_test_model_eval().

Return type:

None

on_test_start()[source]

Called at the beginning of testing.

Return type:

None

on_train_batch_end(outputs, batch, batch_idx)[source]

Called in the training loop after the batch.

Parameters:
  • outputs (Union[Tensor, Mapping[str, Any], None]) – The outputs of training_step(x)

  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type:

None

on_train_batch_start(batch, batch_idx)[source]

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters:
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type:

Optional[int]
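
Example: a minimal sketch; should_skip_rest_of_epoch is a hypothetical attribute you would maintain yourself:

def on_train_batch_start(self, batch, batch_idx):
    if self.should_skip_rest_of_epoch:
        return -1  # skip training for the rest of this epoch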

on_train_end()[source]

Called at the end of training, before the logger experiment is closed.

Return type:

None

on_train_epoch_end()[source]

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

Return type:

None

on_train_epoch_start()[source]

Called in the training loop at the very beginning of the epoch.

Return type:

None

on_train_start()[source]

Called at the beginning of training, after the sanity check.

Return type:

None

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the validation loop after the batch.

Parameters:
  • outputs (Union[Tensor, Mapping[str, Any], None]) – The outputs of validation_step(x)

  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_validation_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the validation loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_validation_end()[source]

Called at the end of validation.

Return type:

None

on_validation_epoch_end()[source]

Called in the validation loop at the very end of the epoch.

Return type:

None
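
The same output-caching pattern shown for on_train_epoch_end() applies here; a minimal sketch, assuming self.validation_step_outputs is a list created in __init__:

def validation_step(self, batch, batch_idx):
    loss = ...
    self.validation_step_outputs.append(loss)
    return loss

def on_validation_epoch_end(self):
    epoch_mean = torch.stack(self.validation_step_outputs).mean()
    self.log("validation_epoch_mean", epoch_mean)
    self.validation_step_outputs.clear()  # free up the memory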

on_validation_epoch_start()[source]

Called in the validation loop at the very beginning of the epoch.

Return type:

None

on_validation_model_eval()[source]

Called when the validation loop starts.

The validation loop by default calls .eval() on the LightningModule before it starts. Override this hook to change the behavior. See also on_validation_model_train().

Return type:

None

on_validation_model_train()[source]

Called when the validation loop ends.

The validation loop by default restores the training mode of the LightningModule to what it was before starting validation. Override this hook to change the behavior. See also on_validation_model_eval().

Return type:

None

on_validation_model_zero_grad()[source]

Called by the training loop to release gradients before entering the validation loop.

Return type:

None

on_validation_start()[source]

Called at the beginning of validation.

Return type:

None