ModelHooks

class lightning.pytorch.core.hooks.ModelHooks[source]

Bases: object

Hooks to be used in LightningModule.

configure_sharded_model()[source]

Hook to create modules in a distributed-aware context. This is useful with sharded plugins, where the model should be sharded as soon as it is created; for extremely large models this saves memory and initialization time.

This hook is called during each of the fit/val/test/predict stages in the same process, so make sure its implementation is idempotent.

Return type

None
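
One way to satisfy the idempotency requirement is to guard the construction with a check, so that repeated calls across stages do nothing after the first. The sketch below is a plain-Python mixin meant to be combined with a LightningModule subclass; the names LazyBlockMixin, build_block and build_count are hypothetical, and build_block returns a placeholder where a real model would create its large layers.

```python
class LazyBlockMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    def __init__(self):
        super().__init__()
        self.block = None     # created lazily inside the hook
        self.build_count = 0  # bookkeeping for this sketch only

    def build_block(self):
        # Placeholder: a real model would construct its large layers here.
        self.build_count += 1
        return ["placeholder layer"]

    def configure_sharded_model(self):
        # The hook may run once per fit/val/test/predict stage in the same
        # process, so return early if the modules already exist.
        if self.block is not None:
            return
        self.block = self.build_block()
```

Calling the hook twice on the same instance builds the block only once, which is the property the Trainer relies on when it runs several stages back to back.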

on_after_backward()[source]

Called after loss.backward() and before optimizers are stepped.

Note

If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step if you need the unscaled gradients.

Return type

None

on_before_backward(loss)[source]

Called before loss.backward().

Parameters

loss (Tensor) – Loss divided by number of batches for gradient accumulation and scaled if using AMP.

Return type

None

on_before_optimizer_step(optimizer)[source]

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.

If using AMP, the gradients will be unscaled before this hook is called. See the PyTorch AMP documentation for more information on the scaling of gradients.

If clipping gradients, the gradients will not have been clipped yet.

Parameters

optimizer (Optimizer) – Current optimizer being used.

Example:

def on_before_optimizer_step(self, optimizer):
    # Inspect gradient information in TensorBoard (assumes a TensorBoard logger).
    if self.trainer.global_step % 25 == 0:  # don't make the event file huge
        for k, v in self.named_parameters():
            if v.grad is None:  # skip parameters that received no gradient
                continue
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )

Return type

None

on_before_zero_grad(optimizer)[source]

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information after the weights have been updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer) # < ---- called here
    optimizer.zero_grad()

    backward()

Parameters

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type

None

on_fit_end()[source]

Called at the very end of fit.

When using DDP, it is called on every process.

Return type

None

on_fit_start()[source]

Called at the very beginning of fit.

When using DDP, it is called on every process.

Return type

None
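
As an illustration of pairing on_fit_start with on_fit_end, the sketch below measures the wall-clock duration of a fit. It is a plain-Python mixin intended to be combined with a LightningModule subclass; FitTimerMixin and fit_seconds are hypothetical names. Because both hooks run on every process under DDP, each process would record its own duration.

```python
import time


class FitTimerMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    def on_fit_start(self):
        # Runs at the very beginning of fit, on every process under DDP.
        self._fit_started_at = time.perf_counter()

    def on_fit_end(self):
        # Runs at the very end of fit; store the elapsed wall-clock seconds.
        self.fit_seconds = time.perf_counter() - self._fit_started_at
```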

on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the predict loop after the batch.

Parameters
  • outputs (Optional[Any]) – The outputs of predict_step(x)

  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None
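
A possible use of this hook is to group prediction outputs by dataloader. The following sketch assumes nothing about the type of outputs beyond it being optional; PredictionCollectorMixin and predictions are hypothetical names, and the mixin is meant to be combined with a LightningModule subclass.

```python
from collections import defaultdict


class PredictionCollectorMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    def on_predict_start(self):
        # Start every predict run with an empty collection.
        self.predictions = defaultdict(list)

    def on_predict_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # outputs is Optional, so skip batches where predict_step returned None.
        if outputs is None:
            return
        self.predictions[dataloader_idx].append(outputs)
```

Keeping every output in memory can be costly for large prediction sets, so a real implementation might write results to disk instead.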

on_predict_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the predict loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_predict_end()[source]

Called at the end of predicting.

Return type

None

on_predict_epoch_end()[source]

Called in the predict loop at the very end of the epoch.

Return type

None

on_predict_epoch_start()[source]

Called in the predict loop at the very beginning of the epoch.

Return type

None

on_predict_model_eval()[source]

Sets the model to eval during the predict loop.

Return type

None

on_predict_start()[source]

Called at the beginning of predicting.

Return type

None

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the test loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of test_step(x)

  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_test_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the test loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None
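
The batch start and end hooks of the test loop can be paired to time each batch. This is a minimal sketch under stated assumptions: BatchLatencyMixin and test_batch_seconds are hypothetical names, the mixin is meant to be combined with a LightningModule subclass, and wall-clock timing may under-report work on accelerators that execute asynchronously.

```python
import time


class BatchLatencyMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    def on_test_start(self):
        # One (dataloader_idx, batch_idx, seconds) entry per test batch.
        self.test_batch_seconds = []

    def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0):
        # Called before anything happens for this batch.
        self._batch_started_at = time.perf_counter()

    def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # Called after the batch; record how long it took.
        elapsed = time.perf_counter() - self._batch_started_at
        self.test_batch_seconds.append((dataloader_idx, batch_idx, elapsed))
```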

on_test_end()[source]

Called at the end of testing.

Return type

None

on_test_epoch_end()[source]

Called in the test loop at the very end of the epoch.

Return type

None

on_test_epoch_start()[source]

Called in the test loop at the very beginning of the epoch.

Return type

None

on_test_model_eval()[source]

Sets the model to eval during the test loop.

Return type

None

on_test_model_train()[source]

Sets the model back to train mode when the test loop ends.

Return type

None

on_test_start()[source]

Called at the beginning of testing.

Return type

None

on_train_batch_end(outputs, batch, batch_idx)[source]

Called in the training loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any]]) – The outputs of training_step(x)

  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type

None

on_train_batch_start(batch, batch_idx)[source]

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type

Optional[int]
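
The -1 return value can be used to cut an epoch short. The sketch below stops training for the rest of the epoch once a batch budget is reached; BatchBudgetMixin and max_batches_per_epoch are hypothetical names, and the mixin is meant to be combined with a LightningModule subclass.

```python
class BatchBudgetMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    max_batches_per_epoch = 100  # hypothetical budget, adjust as needed

    def on_train_batch_start(self, batch, batch_idx):
        # Returning -1 skips training for the rest of the current epoch.
        if batch_idx >= self.max_batches_per_epoch:
            return -1
        # Returning None lets the batch proceed as usual.
        return None
```

For a fixed budget, the Trainer's limit_train_batches argument is likely the simpler option; a hook like this is more useful when the stopping condition depends on runtime state.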

on_train_end()[source]

Called at the end of training, before the logger experiment is closed.

Return type

None

on_train_epoch_end()[source]

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

import lightning as L
import torch


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...  # compute the loss from the batch
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

Return type

None

on_train_epoch_start()[source]

Called in the training loop at the very beginning of the epoch.

Return type

None

on_train_start()[source]

Called at the beginning of training, after the sanity check.

Return type

None

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)[source]

Called in the validation loop after the batch.

Parameters
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of validation_step(x)

  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_batch_start(batch, batch_idx, dataloader_idx=0)[source]

Called in the validation loop before anything happens for that batch.

Parameters
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type

None

on_validation_end()[source]

Called at the end of validation.

Return type

None

on_validation_epoch_end()[source]

Called in the validation loop at the very end of the epoch.

Return type

None
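
The validation hooks can be combined in the same cache-and-reduce pattern used for training outputs. The sketch below averages a per-batch loss over the epoch. It assumes that validation_step returns either a scalar or a dict with a "loss" key, which is an assumption of this example rather than a requirement of the API; ValidationMeanMixin and val_epoch_mean are hypothetical names, and the mixin is meant to be combined with a LightningModule subclass.

```python
class ValidationMeanMixin:
    """Hypothetical mixin; combine it with a LightningModule subclass."""

    def on_validation_epoch_start(self):
        # Start each validation epoch with an empty cache.
        self._val_losses = []

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # outputs may be None when validation_step returns nothing.
        if outputs is None:
            return
        # Assumed convention: dict outputs carry the loss under "loss".
        loss = outputs["loss"] if isinstance(outputs, dict) else outputs
        # float() also converts a one-element tensor to a Python number.
        self._val_losses.append(float(loss))

    def on_validation_epoch_end(self):
        # Reduce the cached values, then free the cache.
        if self._val_losses:
            self.val_epoch_mean = sum(self._val_losses) / len(self._val_losses)
        self._val_losses.clear()
```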

on_validation_epoch_start()[source]

Called in the validation loop at the very beginning of the epoch.

Return type

None

on_validation_model_eval()[source]

Sets the model to eval during the validation loop.

Return type

None

on_validation_model_train()[source]

Sets the model back to train mode when the validation loop ends.

Return type

None

on_validation_start()[source]

Called at the beginning of validation.

Return type

None