Source code for pytorch_lightning.core.decorators
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Decorator for LightningModule methods."""
from functools import wraps
from typing import Callable
from pytorch_lightning.utilities import rank_zero_deprecation, rank_zero_warn
[docs]def auto_move_data(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.core.lightning.LightningModule` methods for which
input arguments should be moved automatically to the correct device.
It as no effect if applied to a method of an object that is not an instance of
:class:`~pytorch_lightning.core.lightning.LightningModule` and is typically applied to ``__call__``
or ``forward``.
Args:
fn: A LightningModule method for which the arguments should be moved to the device
the parameters are on.
Example::
# directly in the source code
class LitModel(LightningModule):
@auto_move_data
def forward(self, x):
return x
# or outside
LitModel.forward = auto_move_data(LitModel.forward)
model = LitModel()
model = model.to('cuda')
model(torch.zeros(1, 3))
# input gets moved to device
# tensor([[0., 0., 0.]], device='cuda:0')
"""
@wraps(fn)
def auto_transfer_args(self, *args, **kwargs):
from pytorch_lightning.core.lightning import LightningModule
if not isinstance(self, LightningModule):
return fn(self, *args, **kwargs)
args, kwargs = self.transfer_batch_to_device((args, kwargs), device=self.device, dataloader_idx=None)
return fn(self, *args, **kwargs)
rank_zero_deprecation(
"The `@auto_move_data` decorator is deprecated in v1.3 and will be removed in v1.5."
f" Please use `trainer.predict` instead for inference. The decorator was applied to `{fn.__name__}`"
)
return auto_transfer_args
[docs]def parameter_validation(fn: Callable) -> Callable:
"""
Validates that the module parameter lengths match after moving to the device. It is useful
when tying weights on TPU's.
Args:
fn: ``model_to_device`` method
Note:
TPU's require weights to be tied/shared after moving the module to the device.
Failure to do this results in the initialization of new weights which are not tied.
To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook
which is called after the module has been moved to the device.
See Also:
- `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
"""
@wraps(fn)
def inner_fn(self, *args, **kwargs):
pre_layer_count = len(list(self.model.parameters()))
module = fn(self, *args, **kwargs)
self.model.on_post_move_to_device()
post_layer_count = len(list(self.model.parameters()))
if not pre_layer_count == post_layer_count:
rank_zero_warn(
f"The model layers do not match after moving to the target device."
" If your model employs weight sharing on TPU,"
" please tie your weights using the `on_post_move_to_device` model hook.\n"
f"Layer count: [Before: {pre_layer_count} After: {post_layer_count}]"
)
return module
return inner_fn