
decorators

Functions

auto_move_data – Decorator for LightningModule methods for which input arguments should be moved automatically to the correct device.

parameter_validation – Validates that the number of module parameters is unchanged after moving the module to the device.

Decorator for LightningModule methods.

pytorch_lightning.core.decorators.auto_move_data(fn)

Decorator for LightningModule methods for which input arguments should be moved automatically to the correct device. It has no effect if applied to a method of an object that is not an instance of LightningModule, and it is typically applied to __call__ or forward.

Parameters

fn (Callable) – A LightningModule method for which the arguments should be moved to the device the parameters are on.

Return type

Callable

Example:

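import torch
from pytorch_lightning import LightningModule
from pytorch_lightning.core.decorators import auto_move_data

# directly in the source code
class LitModel(LightningModule):

    @auto_move_data
    def forward(self, x):
        return x

# or, alternatively, wrap the method outside the class body
LitModel.forward = auto_move_data(LitModel.forward)

model = LitModel()
model = model.to('cuda')  # requires a CUDA-capable GPU
model(torch.zeros(1, 3))

# the input is moved to the model's device:
# tensor([[0., 0., 0.]], device='cuda:0')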
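To make the mechanism concrete, here is a minimal sketch of what a decorator like auto_move_data does, assuming only that the decorated object exposes a `device` attribute and that movable values have a `.to(device)` method, as torch.Tensor does. `auto_move_data_sketch`, `move_to_device`, `FakeTensor` and `Model` are hypothetical names introduced for illustration. This is not Lightning's implementation: the real decorator checks for a LightningModule instance, whereas this sketch falls back to a no-op when `self` has no `device`. `FakeTensor` stands in for torch.Tensor so the sketch runs without PyTorch or a GPU.

```python
import functools
from typing import Any, Callable


def move_to_device(obj: Any, device: Any) -> Any:
    """Recursively move every object exposing a ``.to(device)`` method.

    Lists, tuples and dicts are traversed (tuples come back as plain tuples);
    values without ``.to`` (ints, strings, None, ...) are returned unchanged.
    """
    if callable(getattr(obj, "to", None)):
        return obj.to(device)
    if isinstance(obj, list):
        return [move_to_device(item, device) for item in obj]
    if isinstance(obj, tuple):
        return tuple(move_to_device(item, device) for item in obj)
    if isinstance(obj, dict):
        return {key: move_to_device(value, device) for key, value in obj.items()}
    return obj


def auto_move_data_sketch(fn: Callable) -> Callable:
    """Hypothetical stand-in for auto_move_data (not Lightning's code).

    Moves positional and keyword arguments to ``self.device`` before calling
    ``fn``. If ``self`` has no ``device`` attribute, the method runs unchanged.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        device = getattr(self, "device", None)
        if device is None:
            # Nothing to move to: behave like the undecorated method.
            return fn(self, *args, **kwargs)
        args = move_to_device(args, device)
        kwargs = move_to_device(kwargs, device)
        return fn(self, *args, **kwargs)

    return wrapper


class FakeTensor:
    """Hypothetical stand-in for torch.Tensor that only records its device."""

    def __init__(self, device: str = "cpu") -> None:
        self.device = device

    def to(self, device: str) -> "FakeTensor":
        return FakeTensor(device)


class Model:
    """Hypothetical model whose parameters are assumed to live on cuda:0."""

    device = "cuda:0"

    @auto_move_data_sketch
    def forward(self, x, extra=None):
        return x, extra


x, extra = Model().forward(FakeTensor(), extra=[FakeTensor(), 3])
print(x.device, extra[0].device, extra[1])  # cuda:0 cuda:0 3
```

Nested containers are traversed so that, for example, a list of tensors passed as a keyword argument is moved as well, while plain Python values pass through untouched.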

pytorch_lightning.core.decorators.parameter_validation(fn)

Validates that the number of module parameters is unchanged after moving the module to the device. This is useful when tying weights on TPUs.

Parameters

fn (Callable) – model_to_device method

Return type

Callable

Note

TPUs require weights to be tied (shared) after the module has been moved to the device. Failing to do so results in new weights being initialized that are not tied. To avoid this, tie the weights in the on_post_move_to_device model hook, which is called after the module has been moved to the device.
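The validation and the hook can be sketched together as follows. This is a hypothetical re-creation, not Lightning's code: `parameter_validation_sketch`, `ToyModel`, `Weight`, the `retie` flag and the `model_to_device` method body are illustrative names, and the sketch assumes that the wrapper itself calls `on_post_move_to_device` after the move. It relies on the fact that torch.nn.Module.parameters() yields a shared parameter only once, so a tie broken by the move shows up as a higher parameter count; `ToyModel.parameters()` imitates that de-duplication in plain Python.

```python
import functools
import warnings
from typing import Callable, List


def parameter_validation_sketch(fn: Callable) -> Callable:
    """Hypothetical stand-in for parameter_validation (not Lightning's code).

    Counts self.parameters() before and after the wrapped move method runs,
    calls the on_post_move_to_device hook so the model can re-tie shared
    weights, and warns if the count changed.
    """

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        before = len(list(self.parameters()))
        result = fn(self, *args, **kwargs)
        self.on_post_move_to_device()  # re-tie weights before counting again
        after = len(list(self.parameters()))
        if before != after:
            warnings.warn(
                f"Parameter count changed after moving to device "
                f"(before={before}, after={after}); tie shared weights in "
                "on_post_move_to_device."
            )
        return result

    return wrapper


class Weight:
    """Hypothetical placeholder for a parameter tensor."""


class ToyModel:
    """Two layers sharing one weight; parameters() de-duplicates by identity."""

    def __init__(self, retie: bool) -> None:
        self.retie = retie
        self.layer_1 = Weight()
        self.layer_3 = self.layer_1  # tied: both names refer to one object

    def parameters(self) -> List[Weight]:
        unique: List[Weight] = []
        for w in (self.layer_1, self.layer_3):
            if not any(w is u for u in unique):
                unique.append(w)
        return unique

    @parameter_validation_sketch
    def model_to_device(self) -> None:
        # Simulate a move that gives each layer a fresh copy, breaking the tie.
        self.layer_1 = Weight()
        self.layer_3 = Weight()

    def on_post_move_to_device(self) -> None:
        if self.retie:
            self.layer_3 = self.layer_1  # re-tie after the move


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    ToyModel(retie=True).model_to_device()   # count stays 1: no warning
    ToyModel(retie=False).model_to_device()  # count goes 1 -> 2: warning
print(len(caught))  # 1
```

With `retie=False` the move leaves two independent weights, so the count rises from 1 to 2 and a warning is emitted; re-tying in the hook keeps the count at 1. In a real LightningModule the hook body would typically reassign one layer's weight to another, for example `self.layer_3.weight = self.layer_1.weight`.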