ModelHooks¶
- class pytorch_lightning.core.hooks.ModelHooks[source]¶
Bases: object
Hooks to be used in LightningModule.
- configure_sharded_model()[source]¶
Hook to create modules in a distributed-aware context. This is useful when using sharded plugins, where we would like to shard the model immediately on creation; for extremely large models this can save memory and initialization time.
This hook is called during each of the fit/validate/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent.
- Return type:
None
- on_after_backward()[source]¶
Called after loss.backward() and before optimizers are stepped.
- Return type:
None
Note
If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step hook if you need the unscaled gradients.
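A common use of this hook is logging gradient statistics right after the backward pass. The sketch below is a minimal stand-in, assuming nothing beyond what the doc states: the `Param` and `GradNormModule` classes are hypothetical placeholders so the pattern runs without PyTorch; in a real LightningModule you would iterate self.named_parameters() instead.

```python
# Hedged sketch: computing per-parameter gradient norms in on_after_backward.
# "Param" is a hypothetical stand-in for a torch parameter with a .grad field.
class Param:
    def __init__(self, grad):
        self.grad = grad

class GradNormModule:
    def __init__(self):
        # hypothetical parameters with already-populated gradients
        self.params = {"weight": Param([0.3, -0.4]), "bias": Param([0.0])}
        self.logged = {}

    def on_after_backward(self):
        for name, p in self.params.items():
            # L2 norm of the gradients (still scaled under native AMP,
            # per the Note above)
            norm = sum(g * g for g in p.grad) ** 0.5
            self.logged[f"grad_norm/{name}"] = norm

m = GradNormModule()
m.on_after_backward()
print(m.logged["grad_norm/weight"])  # → 0.5
```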
- on_before_optimizer_step(optimizer, optimizer_idx)[source]¶
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated. See accumulate_grad_batches.
If using native AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.
If clipping gradients, the gradients will not have been clipped yet.
- Parameters:
optimizer – the current optimizer being used
optimizer_idx – the index of the current optimizer being used
- Return type:
None
Example:
def on_before_optimizer_step(self, optimizer, optimizer_idx):
    # example to inspect gradient information in tensorboard
    if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
        for k, v in self.named_parameters():
            self.logger.experiment.add_histogram(
                tag=k, values=v.grad, global_step=self.trainer.global_step
            )
- on_before_zero_grad(optimizer)[source]¶
Called after training_step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. A good place to inspect weight information with the weights updated.
This is where it is called:
for optimizer in optimizers:
    out = training_step(...)
    model.on_before_zero_grad(optimizer)  # < ---- called here
    optimizer.zero_grad()
    backward()
- on_fit_end()[source]¶
Called at the very end of fit.
If on DDP, it is called on every process.
- Return type:
None
- on_fit_start()[source]¶
Called at the very beginning of fit.
If on DDP, it is called on every process.
- Return type:
None
- on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the predict loop after the batch.
- on_predict_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the predict loop before anything happens for that batch.
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the test loop after the batch.
- on_test_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the test loop before anything happens for that batch.
- on_test_epoch_start()[source]¶
Called in the test loop at the very beginning of the epoch.
- Return type:
None
- on_train_batch_start(batch, batch_idx)[source]¶
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
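The skip-the-rest-of-the-epoch contract can be sketched without Lightning at all: the training loop checks the hook's return value and abandons the remaining batches when it sees -1. `SkippingModule` and `run_epoch` below are hypothetical stand-ins for the LightningModule and the trainer's inner loop.

```python
# Hedged sketch of the on_train_batch_start contract: returning -1
# skips training for the rest of the current epoch.
class SkippingModule:
    def __init__(self, max_batches):
        self.max_batches = max_batches  # hypothetical budget per epoch

    def on_train_batch_start(self, batch, batch_idx):
        if batch_idx >= self.max_batches:
            return -1  # signal: skip the remaining batches of this epoch
        return None

def run_epoch(module, batches):
    # stand-in for the trainer's inner loop, which honors the -1 signal
    processed = []
    for batch_idx, batch in enumerate(batches):
        if module.on_train_batch_start(batch, batch_idx) == -1:
            break
        processed.append(batch)
    return processed

print(run_epoch(SkippingModule(max_batches=2), ["b0", "b1", "b2", "b3"]))
# → ['b0', 'b1']
```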
- on_train_end()[source]¶
Called at the end of training before logger experiment is closed.
- Return type:
None
- on_train_epoch_end()[source]¶
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, either:
Implement training_epoch_end in the LightningModule, OR
Cache data across steps on the attribute(s) of the LightningModule and access them in this hook
- Return type:
None
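The second option, caching across steps, can be sketched with a plain class standing in for the LightningModule. The attribute name `training_step_outputs` is a hypothetical choice, not a Lightning API; clearing it at the end of the hook keeps memory bounded across epochs.

```python
# Hedged sketch: cache per-step outputs on a module attribute and
# consume them in on_train_epoch_end. Plain class, no Lightning needed.
class CachingModule:
    def __init__(self):
        self.training_step_outputs = []  # hypothetical cache attribute

    def training_step(self, batch):
        loss = sum(batch) / len(batch)  # stand-in for a real loss
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        outs = self.training_step_outputs
        epoch_mean = sum(outs) / len(outs)
        self.training_step_outputs.clear()  # free the cache for next epoch
        return epoch_mean

module = CachingModule()
for batch in ([1.0, 3.0], [2.0, 4.0]):
    module.training_step(batch)
print(module.on_train_epoch_end())  # → 2.5
```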
- on_train_epoch_start()[source]¶
Called in the training loop at the very beginning of the epoch.
- Return type:
None
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the validation loop after the batch.
- Parameters:
outputs – the outputs of validation_step(x)
batch – the batched data as it is returned by the validation DataLoader
batch_idx – the index of the batch
dataloader_idx – the index of the dataloader
- Return type:
None
- on_validation_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the validation loop before anything happens for that batch.
- on_validation_epoch_end()[source]¶
Called in the validation loop at the very end of the epoch.
- Return type:
None