ModelHooks
- class lightning.pytorch.core.hooks.ModelHooks
Bases: object
Hooks to be used in LightningModule.
- configure_sharded_model()
Hook to create modules in a distributed-aware context. This is useful when using sharded plugins, where we’d like to shard the model instantly; for extremely large models, this can save memory and initialization time.
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent.
- Return type: None
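Because the hook can run once per stage in the same process, one common way to keep it idempotent is to return early when the layers already exist. The plain-Python sketch below uses a hypothetical stand-in class (ToyModule) and a placeholder build_layer factory instead of a real LightningModule and sharding plugin, so it only demonstrates the guard pattern:

```python
class ToyModule:
    """Hypothetical stand-in for a LightningModule; not the Lightning API."""

    def __init__(self):
        self.layer = None
        self.build_count = 0  # counts how many times the layer was actually created

    def build_layer(self):
        # Placeholder for creating a large (possibly sharded) submodule.
        self.build_count += 1
        return {"weights": [0.0] * 4}

    def configure_sharded_model(self):
        # Idempotency guard: the hook may be called again for val/test/predict
        # in the same process, so only create the layer the first time.
        if self.layer is not None:
            return
        self.layer = self.build_layer()


module = ToyModule()
for stage in ("fit", "validate", "test", "predict"):
    module.configure_sharded_model()  # simulate one call per stage

print(module.build_count)  # 1
```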
- on_after_backward()
Called after loss.backward() and before optimizers are stepped.
- Return type: None
Note
If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step hook if you need the unscaled gradients.
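To see why the gradients are still scaled in this hook, consider the arithmetic behind AMP loss scaling: backpropagating S * loss yields S times the true gradient, and unscaling divides by S again. The plain-Python sketch below uses the toy loss w ** 2 (gradient 2w) and an arbitrary scale factor of 1024; it illustrates the math only and does not use torch.cuda.amp.GradScaler.

```python
SCALE = 1024.0  # arbitrary power-of-two scale factor chosen for illustration


def grad_of_loss(w):
    # Toy loss L(w) = w ** 2, so dL/dw = 2 * w.
    return 2.0 * w


w = 3.0
true_grad = grad_of_loss(w)

# With loss scaling, backward runs on SCALE * L(w), so the stored gradient is scaled.
scaled_grad = SCALE * grad_of_loss(w)  # roughly what on_after_backward would see

# Unscaling (done before on_before_optimizer_step) divides the scale back out.
unscaled_grad = scaled_grad / SCALE

print(true_grad, scaled_grad, unscaled_grad)  # 6.0 6144.0 6.0
```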
- on_before_optimizer_step(optimizer)
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.
If using AMP, the gradients will be unscaled before calling this hook. See the PyTorch automatic mixed precision documentation for more information on the scaling of gradients.
If clipping gradients, the gradients will not have been clipped yet.
Example:
```python
def on_before_optimizer_step(self, optimizer):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            if v.grad is None:  # skip parameters that received no gradient
                continue
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
```
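With gradient accumulation, the optimizer only steps every accumulate_grad_batches batches, so this hook fires far less often than on_after_backward. The toy loop below is a plain-Python model of that documented behavior, not Lightning's training loop. It assumes a step also happens on the final batch of the epoch when the batch count is not a multiple of the accumulation factor, and it records where gradient clipping would happen relative to the hook.

```python
def simulate_epoch(num_batches, accumulate_grad_batches):
    """Return the ordered events of a toy epoch (hypothetical model of the loop)."""
    events = []
    for batch_idx in range(num_batches):
        events.append(("backward", batch_idx))
        is_last = batch_idx == num_batches - 1
        # Step once enough gradients have been accumulated (or on the last batch).
        if (batch_idx + 1) % accumulate_grad_batches == 0 or is_last:
            events.append(("on_before_optimizer_step", batch_idx))
            events.append(("clip_gradients", batch_idx))  # clipping comes after the hook
            events.append(("optimizer.step", batch_idx))
    return events


events = simulate_epoch(num_batches=10, accumulate_grad_batches=4)
hook_batches = [idx for name, idx in events if name == "on_before_optimizer_step"]
print(hook_batches)  # [3, 7, 9]
```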
- on_before_zero_grad(optimizer)
Called after training_step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. This is a good place to inspect weight information with the weights updated.
This is where it is called:
```python
for optimizer in optimizers:
    out = training_step(...)

    model.on_before_zero_grad(optimizer)  # < ---- called here
    optimizer.zero_grad()

    backward()
```
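Because the hook runs after optimizer.step() but before optimizer.zero_grad(), it sees the updated weights while the gradients from the last backward pass are still populated. The plain-Python sketch below models one SGD update on a single scalar weight with a hypothetical ToyParam class; it mirrors the ordering described above rather than calling Lightning or PyTorch.

```python
class ToyParam:
    """Hypothetical scalar parameter with a .grad field, loosely mimicking a tensor."""

    def __init__(self, value):
        self.value = value
        self.grad = None


def sgd_step(param, lr):
    # Plain SGD update: w <- w - lr * grad
    param.value -= lr * param.grad


observed = {}


def on_before_zero_grad(param):
    # Record what the hook sees: updated weight, gradient not yet cleared.
    observed["weight"] = param.value
    observed["grad"] = param.grad


w = ToyParam(1.0)
w.grad = 0.5             # pretend backward() produced this gradient
sgd_step(w, lr=0.5)      # optimizer.step(): 1.0 - 0.5 * 0.5 = 0.75
on_before_zero_grad(w)   # hook runs here
w.grad = None            # optimizer.zero_grad() with set_to_none behavior

print(observed)  # {'weight': 0.75, 'grad': 0.5}
```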
- on_fit_end()
Called at the very end of fit.
If on DDP, it is called on every process.
- Return type: None
- on_fit_start()
Called at the very beginning of fit.
If on DDP, it is called on every process.
- Return type: None
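Since on_fit_start and on_fit_end run on every DDP process, side effects such as writing a file or printing a summary are usually restricted to global rank 0 (inside a LightningModule this is commonly checked via self.global_rank or self.trainer.is_global_zero). The sketch below simulates four processes in plain Python; the world size and the ToyRankModule class are hypothetical.

```python
class ToyRankModule:
    """Hypothetical per-process module; global_rank mimics the process's rank."""

    def __init__(self, global_rank, shared_log):
        self.global_rank = global_rank
        self.shared_log = shared_log  # stands in for a file or external service
        self.fit_started = False

    def on_fit_start(self):
        # Runs on every process: per-process setup is fine here.
        self.fit_started = True
        # Side effects that should happen only once are guarded by a rank check.
        if self.global_rank == 0:
            self.shared_log.append("fit started")

    def on_fit_end(self):
        if self.global_rank == 0:
            self.shared_log.append("fit finished")


WORLD_SIZE = 4  # assumed number of DDP processes for this simulation
log = []
modules = [ToyRankModule(rank, log) for rank in range(WORLD_SIZE)]
for m in modules:
    m.on_fit_start()
for m in modules:
    m.on_fit_end()

print(log)  # ['fit started', 'fit finished']
```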
- on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx=0)
Called in the predict loop after the batch.
- on_predict_batch_start(batch, batch_idx, dataloader_idx=0)
Called in the predict loop before anything happens for that batch.
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx=0)
Called in the test loop after the batch.
- on_test_batch_start(batch, batch_idx, dataloader_idx=0)
Called in the test loop before anything happens for that batch.
- on_test_epoch_start()
Called in the test loop at the very beginning of the epoch.
- Return type: None
- on_train_batch_start(batch, batch_idx)
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
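The effect of returning -1 can be modeled with a toy loop: once the hook returns -1, the remaining batches of the current epoch are skipped and the next epoch starts normally. The sketch below is plain Python, not Lightning's loop; the stop condition (a hypothetical budget of 3 batches per epoch) is made up for illustration.

```python
BATCH_BUDGET = 3  # hypothetical: only train on the first 3 batches of each epoch


def on_train_batch_start(batch, batch_idx):
    # Returning -1 signals "skip the rest of this epoch".
    if batch_idx >= BATCH_BUDGET:
        return -1
    return None


def run_training(num_epochs, batches_per_epoch):
    trained = []  # (epoch, batch_idx) pairs that actually reached training_step
    for epoch in range(num_epochs):
        for batch_idx in range(batches_per_epoch):
            if on_train_batch_start(batch=None, batch_idx=batch_idx) == -1:
                break  # skip only the remaining batches of this epoch
            trained.append((epoch, batch_idx))
    return trained


print(run_training(num_epochs=2, batches_per_epoch=5))
# [(0, 0), (0, 1), (0, 2), (1, 0), (1, 1), (1, 2)]
```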
- on_train_end()
Called at the end of training before the logger experiment is closed.
- Return type: None
- on_train_epoch_end()
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:
```python
import torch
import lightning as L


class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self, batch, batch_idx):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
```
- Return type: None
- on_train_epoch_start()
Called in the training loop at the very beginning of the epoch.
- Return type: None
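A typical use of this hook is a per-epoch adjustment, for example unfreezing a pretrained backbone after a warm-up period. In a LightningModule the epoch index is available as self.current_epoch; the sketch below uses a hypothetical stand-in class and an assumed UNFREEZE_AT_EPOCH of 2 to show the pattern in plain Python.

```python
UNFREEZE_AT_EPOCH = 2  # assumed warm-up length, chosen for illustration


class ToyFinetuneModule:
    """Hypothetical stand-in; backbone_trainable mimics requires_grad on a backbone."""

    def __init__(self):
        self.current_epoch = 0
        self.backbone_trainable = False  # start with the backbone frozen

    def on_train_epoch_start(self):
        # Runs before the first batch of each epoch.
        if self.current_epoch == UNFREEZE_AT_EPOCH:
            self.backbone_trainable = True


module = ToyFinetuneModule()
history = []
for epoch in range(4):
    module.current_epoch = epoch  # the Trainer would advance this for you
    module.on_train_epoch_start()
    history.append(module.backbone_trainable)

print(history)  # [False, False, True, True]
```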
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx=0)
Called in the validation loop after the batch.
- on_validation_batch_start(batch, batch_idx, dataloader_idx=0)
Called in the validation loop before anything happens for that batch.
- on_validation_epoch_end()
Called in the validation loop at the very end of the epoch.
- Return type: None
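The validation hooks can be combined to aggregate a metric over an epoch: collect per-batch values in on_validation_batch_end, reduce them in on_validation_epoch_end, then clear the cache for the next epoch. The plain-Python sketch below uses a hypothetical ToyValModule and made-up batch losses, and assumes outputs is the loss returned by validation_step; in a real LightningModule you would typically pass the reduced value to self.log(...) instead of storing it.

```python
class ToyValModule:
    """Hypothetical stand-in for a LightningModule's validation hooks."""

    def __init__(self):
        self.val_losses = []   # per-batch cache, cleared every epoch
        self.epoch_means = []  # stand-in for values you would pass to self.log

    def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx=0):
        # Assumption: outputs is the batch loss returned by validation_step.
        self.val_losses.append(outputs)

    def on_validation_epoch_end(self):
        mean_loss = sum(self.val_losses) / len(self.val_losses)
        self.epoch_means.append(mean_loss)
        self.val_losses.clear()  # free memory before the next epoch


module = ToyValModule()
for epoch_losses in ([1.0, 2.0, 3.0], [0.5, 1.5]):
    for batch_idx, loss in enumerate(epoch_losses):
        module.on_validation_batch_end(loss, batch=None, batch_idx=batch_idx)
    module.on_validation_epoch_end()

print(module.epoch_means)  # [2.0, 1.0]
```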