ModelHooks¶
- class pytorch_lightning.core.hooks.ModelHooks[source]¶
Bases: object
Hooks to be used in LightningModule.
- configure_sharded_model()[source]¶
Hook to create modules in a distributed aware context. This is useful when using sharded plugins, where we'd like to shard the model immediately, which saves memory and initialization time for extremely large models.
This hook is called during each of the fit/val/test/predict stages in the same process, so ensure that the implementation of this hook is idempotent.
- Return type: None
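A minimal sketch of how a model might defer layer creation to this hook; the layer sizes and the hasattr guard are illustrative assumptions, not part of the API:

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def configure_sharded_model(self):
            # Layers created here are instantiated inside the sharding context,
            # so a very large module never has to materialize fully on one device.
            # The guard keeps the hook idempotent across repeated calls.
            if not hasattr(self, "block"):
                self.block = torch.nn.Sequential(
                    torch.nn.Linear(4096, 4096),
                    torch.nn.ReLU(),
                    torch.nn.Linear(4096, 4096),
                )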
- on_after_backward()[source]¶
Called after loss.backward() and before optimizers are stepped.
Note: If using native AMP, the gradients will not be unscaled at this point. Use the on_before_optimizer_step hook if you need the unscaled gradients.
- Return type: None
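For example, this hook can be used to inspect gradients right after the backward pass. A minimal sketch (remember that under native AMP the gradients here are still scaled); the metric name is illustrative:

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def on_after_backward(self):
            # Gradients are populated here, so a global gradient norm can be logged.
            grads = [p.grad.detach().norm() for p in self.parameters() if p.grad is not None]
            if grads:
                self.log("grad_norm", torch.stack(grads).norm())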
- on_before_optimizer_step(optimizer, optimizer_idx)[source]¶
Called before optimizer.step().
If using gradient accumulation, the hook is called once the gradients have been accumulated. See: accumulate_grad_batches.
If using native AMP, the loss will be unscaled before calling this hook. See these docs for more information on the scaling of gradients.
If clipping gradients, the gradients will not have been clipped yet.
- Parameters
  - optimizer: the current optimizer being used
  - optimizer_idx: the index of the current optimizer being used
Example:
    def on_before_optimizer_step(self, optimizer, optimizer_idx):
        # example to inspect gradient information in tensorboard
        if self.trainer.global_step % 25 == 0:  # don't make the tf file huge
            for k, v in self.named_parameters():
                self.logger.experiment.add_histogram(
                    tag=k, values=v.grad, global_step=self.trainer.global_step
                )
- Return type: None
- on_before_zero_grad(optimizer)[source]¶
Called after training_step() and before optimizer.zero_grad().
Called in the training loop after taking an optimizer step and before zeroing grads. A good place to inspect weight information with weights updated.
This is where it is called:
    for optimizer in optimizers:
        out = training_step(...)
        model.on_before_zero_grad(optimizer)  # < ---- called here
        optimizer.zero_grad()
        backward()
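A minimal sketch of using this hook to inspect the freshly updated weights; the cached attribute name is an illustrative assumption:

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def on_before_zero_grad(self, optimizer):
            # optimizer.step() has just run, so the parameters reflect the update.
            with torch.no_grad():
                self.last_weight_mean = torch.stack(
                    [p.detach().float().mean() for p in self.parameters()]
                ).mean()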
- on_epoch_end()[source]¶
Called when a train, val, or test epoch ends.
Deprecated since version v1.6: on_epoch_end() has been deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_end instead.
- Return type: None
- on_epoch_start()[source]¶
Called when a train, val, or test epoch begins.
Deprecated since version v1.6: on_epoch_start() has been deprecated in v1.6 and will be removed in v1.8. Use on_<train/validation/test>_epoch_start instead.
- Return type: None
- on_fit_end()[source]¶
Called at the very end of fit.
If on DDP, it is called on every process.
- Return type: None
- on_fit_start()[source]¶
Called at the very beginning of fit.
If on DDP, it is called on every process.
- Return type: None
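Since every DDP process runs this hook, a common pattern is to guard one-time setup with a rank check. A minimal sketch; the printed message is illustrative:

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def on_fit_start(self):
            # Under DDP this runs once per process; restrict one-time work to rank zero.
            if self.trainer.is_global_zero:
                print("fit is starting")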
- on_post_move_to_device()[source]¶
Called in the parameter_validation decorator after to() is called. This is a good place to tie weights between modules after moving them to a device. Can be used when training models with weight sharing properties on TPU.
Addresses the handling of shared weights on TPU: https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
Example:
    def on_post_move_to_device(self):
        self.decoder.weight = self.encoder.weight
- Return type: None
- on_predict_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the predict loop after the batch.
- on_predict_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the predict loop before anything happens for that batch.
- on_pretrain_routine_end()[source]¶
Called at the end of the pretrain routine (between fit and train start).
    fit
    pretrain_routine start
    pretrain_routine end
    training_start
Deprecated since version v1.6: on_pretrain_routine_end() has been deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.
- Return type: None
- on_pretrain_routine_start()[source]¶
Called at the beginning of the pretrain routine (between fit and train start).
    fit
    pretrain_routine start
    pretrain_routine end
    training_start
Deprecated since version v1.6: on_pretrain_routine_start() has been deprecated in v1.6 and will be removed in v1.8. Use on_fit_start instead.
- Return type: None
- on_test_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the test loop after the batch.
- on_test_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the test loop before anything happens for that batch.
- on_test_epoch_start()[source]¶
Called in the test loop at the very beginning of the epoch.
- Return type: None
- on_train_batch_end(outputs, batch, batch_idx, unused=0)[source]¶
Called in the training loop after the batch.
- Parameters
  - outputs: the outputs of training_step(x)
  - batch: the batched data as it is returned by the training DataLoader
  - batch_idx: the index of this batch
  - unused: deprecated argument kept for backward compatibility
- Return type: None
- on_train_batch_start(batch, batch_idx, unused=0)[source]¶
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
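For example, returning -1 lets the model skip the rest of the epoch once some condition is met. A minimal sketch with a hypothetical batch limit:

    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def on_train_batch_start(self, batch, batch_idx, unused=0):
            # Hypothetical early exit: process at most 100 batches per epoch.
            if batch_idx >= 100:
                return -1  # skips training for the rest of the current epoch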
- on_train_end()[source]¶
Called at the end of training, before the logger experiment is closed.
- Return type: None
- on_train_epoch_end()[source]¶
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, either implement training_epoch_end in the LightningModule, or cache data across steps on attribute(s) of the LightningModule and access them in this hook (see the sketch below).
- Return type: None
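A minimal sketch of the caching approach, assuming a simple linear model and loss; the attribute and metric names are illustrative:

    import torch
    from pytorch_lightning import LightningModule

    class MyModel(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer = torch.nn.Linear(32, 1)
            self.training_step_losses = []  # cache across steps

        def training_step(self, batch, batch_idx):
            loss = self.layer(batch).mean()
            self.training_step_losses.append(loss.detach())
            return loss

        def on_train_epoch_end(self):
            # All cached step outputs are available here.
            epoch_mean = torch.stack(self.training_step_losses).mean()
            self.log("train_loss_epoch", epoch_mean)
            self.training_step_losses.clear()  # free memory for the next epoch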
- on_train_epoch_start()[source]¶
Called in the training loop at the very beginning of the epoch.
- Return type: None
- on_validation_batch_end(outputs, batch, batch_idx, dataloader_idx)[source]¶
Called in the validation loop after the batch.
- Parameters
  - outputs: the outputs of validation_step(x)
  - batch: the batched data as it is returned by the validation DataLoader
  - batch_idx: the index of this batch
  - dataloader_idx: the index of the dataloader
- Return type: None
- on_validation_batch_start(batch, batch_idx, dataloader_idx)[source]¶
Called in the validation loop before anything happens for that batch.
- on_validation_epoch_end()[source]¶
Called in the validation loop at the very end of the epoch.
- Return type: None