ModelHooks

class pytorch_lightning.core.hooks.ModelHooks

Bases: object

Hooks to be used in LightningModule.

configure_sharded_model()

Hook to create modules in a distributed-aware context. This is useful when using sharded plugins, where the model should be sharded as soon as it is created; for extremely large models, this saves memory and initialization time.

This hook is called during each of the fit/val/test/predict stages in the same process, so make sure your implementation of this hook is idempotent.

Return type:

None

on_after_backward()

Called after loss.backward() and before optimizers are stepped.

Note

If using native AMP, the gradients will not be unscaled at this point. Use on_before_optimizer_step if you need the unscaled gradients.

Return type:

None

on_before_backward(loss)

Called before loss.backward().

Parameters:

loss (Tensor) – The loss, divided by the number of batches used for gradient accumulation, and scaled if using native AMP.

Return type:

None

on_before_optimizer_step(optimizer, optimizer_idx)

Called before optimizer.step().

If using gradient accumulation, the hook is called once the gradients have been accumulated. See accumulate_grad_batches.

If using native AMP, the gradients will be unscaled before calling this hook. See the PyTorch automatic mixed precision examples on working with unscaled gradients for more information on gradient scaling.

If clipping gradients, the gradients will not have been clipped yet.

Parameters:
  • optimizer (Optimizer) – Current optimizer being used.

  • optimizer_idx (int) – Index of the current optimizer being used.

Example:

def on_before_optimizer_step(self, optimizer, optimizer_idx):
    # example to inspect gradient information in TensorBoard
    if self.trainer.global_step % 25 == 0:  # don't make the event file huge
        for name, param in self.named_parameters():
            if param.grad is None:  # skip frozen or unused parameters
                continue
            self.logger.experiment.add_histogram(
                tag=name, values=param.grad, global_step=self.trainer.global_step
            )

Return type:

None

on_before_zero_grad(optimizer)

Called after training_step() and before optimizer.zero_grad().

Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information, since the weights have already been updated.

This is where it is called:

for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer)  # <-- called here
    optimizer.zero_grad()

    backward()

Parameters:

optimizer (Optimizer) – The optimizer for which grads should be zeroed.

Return type:

None

on_epoch_end()

Called when a train, validation, or test epoch ends.

Deprecated since version v1.6: on_epoch_end() has been deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_end instead.

Return type:

None

on_epoch_start()

Called when a train, validation, or test epoch begins.

Deprecated since version v1.6: on_epoch_start() has been deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_start instead.

Return type:

None

on_fit_end()

Called at the very end of fit.

If using DDP, it is called on every process.

Return type:

None

on_fit_start()

Called at the very beginning of fit.

If using DDP, it is called on every process.

Return type:

None

on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the predict loop after the batch.

Parameters:
  • outputs (Optional[Any]) – The outputs of predict_step(x)

  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_predict_batch_start(batch, batch_idx, dataloader_idx)

Called in the predict loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the prediction DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_predict_end()

Called at the end of predicting.

Return type:

None

on_predict_epoch_end(results)

Called at the end of the predict epoch.

Return type:

None

on_predict_epoch_start()

Called at the beginning of the predict epoch.

Return type:

None

on_predict_model_eval()

Sets the model to eval during the predict loop.

Return type:

None

on_predict_start()

Called at the beginning of predicting.

Return type:

None

on_pretrain_routine_end()

Called at the end of the pretrain routine (between fit and train start). The order of events is:

  • fit

  • pretrain_routine start

  • pretrain_routine end

  • training_start

Deprecated since version v1.6: on_pretrain_routine_end() has been deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.

Return type:

None

on_pretrain_routine_start()

Called at the beginning of the pretrain routine (between fit and train start). The order of events is:

  • fit

  • pretrain_routine start

  • pretrain_routine end

  • training_start

Deprecated since version v1.6: on_pretrain_routine_start() has been deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.

Return type:

None

on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the test loop after the batch.

Parameters:
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of test_step_end(test_step(x))

  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_test_batch_start(batch, batch_idx, dataloader_idx)

Called in the test loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the test DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_test_end()

Called at the end of testing.

Return type:

None

on_test_epoch_end()

Called in the test loop at the very end of the epoch.

Return type:

None

on_test_epoch_start()

Called in the test loop at the very beginning of the epoch.

Return type:

None

on_test_model_eval()

Sets the model to eval during the test loop.

Return type:

None

on_test_model_train()

Sets the model back to train mode at the end of the test loop.

Return type:

None

on_test_start()

Called at the beginning of testing.

Return type:

None

on_train_batch_end(outputs, batch, batch_idx)

Called in the training loop after the batch.

Parameters:
  • outputs (Union[Tensor, Dict[str, Any]]) – The outputs of training_step_end(training_step(x))

  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type:

None

on_train_batch_start(batch, batch_idx)

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters:
  • batch (Any) – The batched data as it is returned by the training DataLoader.

  • batch_idx (int) – the index of the batch

Return type:

Optional[int]

on_train_end()

Called at the end of training before the logger experiment is closed.

Return type:

None

on_train_epoch_end()

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, either:

  1. Implement training_epoch_end in the LightningModule, or

  2. Cache data across steps on the attribute(s) of the LightningModule and access them in this hook.

Return type:

None

on_train_epoch_start()

Called in the training loop at the very beginning of the epoch.

Return type:

None

on_train_start()

Called at the beginning of training after the sanity check.

Return type:

None

on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)

Called in the validation loop after the batch.

Parameters:
  • outputs (Union[Tensor, Dict[str, Any], None]) – The outputs of validation_step_end(validation_step(x))

  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_validation_batch_start(batch, batch_idx, dataloader_idx)

Called in the validation loop before anything happens for that batch.

Parameters:
  • batch (Any) – The batched data as it is returned by the validation DataLoader.

  • batch_idx (int) – the index of the batch

  • dataloader_idx (int) – the index of the dataloader

Return type:

None

on_validation_end()

Called at the end of validation.

Return type:

None

on_validation_epoch_end()

Called in the validation loop at the very end of the epoch.

Return type:

None

on_validation_epoch_start()

Called in the validation loop at the very beginning of the epoch.

Return type:

None

on_validation_model_eval()

Sets the model to eval during the validation loop.

Return type:

None

on_validation_model_train()

Sets the model back to train mode at the end of the validation loop.

Return type:

None

on_validation_start()

Called at the beginning of validation.

Return type:

None