decorators
Functions
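| auto_move_data | Decorator for LightningModule methods for which input arguments should be moved automatically to the correct device. |
| parameter_validation | Validates that the module parameter lengths match after moving to the device. |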
Decorator for LightningModule methods.
- pytorch_lightning.core.decorators.auto_move_data(fn)
Decorator for LightningModule methods for which input arguments should be moved automatically to the correct device. It has no effect if applied to a method of an object that is not an instance of LightningModule, and is typically applied to __call__ or forward.
- Parameters
fn (Callable) – A LightningModule method for which the arguments should be moved to the device the parameters are on.
- Return type
Callable
Example:
import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data

# directly in the source code
class LitModel(LightningModule):

    @auto_move_data
    def forward(self, x):
        return x

# or outside
LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))

# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')
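To make the behaviour described above more concrete, here is a minimal sketch of how a decorator of this kind could be written. It is not the library's implementation: `move_inputs_to_module_device` and `TinyModel` are hypothetical names, plain `torch.nn.Module` stands in for LightningModule, and the sketch assumes the module has at least one parameter and that all parameters live on the same device.

```python
from functools import wraps

import torch
from torch import nn


def move_inputs_to_module_device(fn):
    """Hypothetical stand-in for auto_move_data, written against nn.Module."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        # No effect when the method does not belong to an nn.Module instance.
        if not isinstance(self, nn.Module):
            return fn(self, *args, **kwargs)

        # Assumption: at least one parameter exists and all share one device.
        device = next(self.parameters()).device

        def move(value):
            # Only tensors are moved; other argument types pass through unchanged.
            return value.to(device) if isinstance(value, torch.Tensor) else value

        args = tuple(move(a) for a in args)
        kwargs = {k: move(v) for k, v in kwargs.items()}
        return fn(self, *args, **kwargs)

    return wrapper


class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 2)

    @move_inputs_to_module_device
    def forward(self, x):
        return self.linear(x)


model = TinyModel()
if torch.cuda.is_available():
    model = model.to("cuda")

# The input is created on the CPU and moved to the model's device by the wrapper.
out = model(torch.zeros(1, 3))
```

The sketch only handles tensors passed directly as arguments. Tensors nested inside lists, dicts, or custom batch objects would need a recursive traversal, which is left out here for brevity.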
- pytorch_lightning.core.decorators.parameter_validation(fn)
Validates that the module parameter lengths match after moving to the device. It is useful when tying weights on TPUs.
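As an illustration of what such a check might look like, the sketch below counts the module's parameters before and after a wrapped call and warns if the count changed. It is an assumption-laden sketch, not the library's code: `check_parameter_count`, `TiedModel`, and `fake_move` are hypothetical names, plain `torch.nn.Module` stands in for LightningModule, and the untying that can happen on a real device move is simulated on the CPU. The method named `on_post_move_to_device` mirrors the hook mentioned in this section, but it is called by hand here because `nn.Module` does not call it.

```python
import warnings
from functools import wraps

from torch import nn


def check_parameter_count(fn):
    """Hypothetical sketch: warn if the wrapped call changes the parameter count."""

    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        before = len(list(self.parameters()))
        result = fn(self, *args, **kwargs)
        after = len(list(self.parameters()))
        if before != after:
            warnings.warn(
                f"parameter count changed from {before} to {after}; "
                "tied weights may have been untied"
            )
        return result

    return wrapper


class TiedModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 4, bias=False)
        self.decoder = nn.Linear(4, 4, bias=False)
        # Tie the weights: both layers share one Parameter object.
        self.decoder.weight = self.encoder.weight

    @check_parameter_count
    def fake_move(self):
        # Simulates a move that unties the weights by giving the decoder its own copy.
        self.decoder.weight = nn.Parameter(self.encoder.weight.detach().clone())
        return self

    def on_post_move_to_device(self):
        # Re-tie the weights after the (simulated) move.
        self.decoder.weight = self.encoder.weight


model = TiedModel()
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.fake_move()  # the decoder now has its own weight, so the check warns

count_untied = len(list(model.parameters()))
model.on_post_move_to_device()  # called manually in this sketch
count_retied = len(list(model.parameters()))
```

The check relies on `parameters()` yielding each shared Parameter only once, so a tied model reports fewer parameters than an untied one with the same layers.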
Note
TPUs require weights to be tied/shared after moving the module to the device. Failure to do this results in the initialization of new weights which are not tied. To overcome this issue, weights should be tied using the on_post_move_to_device model hook, which is called after the module has been moved to the device.
See also