DeviceDtypeModuleMixin¶
- class pytorch_lightning.core.mixins.DeviceDtypeModuleMixin[source]¶
Bases:
torch.nn.modules.module.Module
Initializes internal Module state, shared by both nn.Module and ScriptModule.
- cuda(device=None)[source]¶
Moves all model parameters and buffers to the GPU. This also makes associated parameters and buffers different objects. So it should be called before constructing the optimizer if the module will live on GPU while being optimized.
- Parameters
device – if specified, all parameters will be copied to that device
- Returns
self
- Return type
Module
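A minimal sketch of why the ordering matters, using a plain torch.nn.Module (from which the mixin inherits this method). Building the optimizer after the move ensures it holds references to the (possibly GPU) parameter objects; the guard keeps the sketch runnable on CPU-only machines:

```python
import torch
from torch import nn

model = nn.Linear(4, 2)
if torch.cuda.is_available():  # only move when a GPU is actually present
    model.cuda()               # replaces parameters with GPU copies

# Construct the optimizer *after* the move, so it references the
# same parameter objects the module now holds.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
assert optimizer.param_groups[0]["params"][0] is model.weight
```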
- double()[source]¶
Casts all floating point parameters and buffers to double datatype.
- Returns
self
- Return type
Module
- float()[source]¶
Casts all floating point parameters and buffers to float datatype.
- Returns
self
- Return type
Module
- half()[source]¶
Casts all floating point parameters and buffers to half datatype.
- Returns
self
- Return type
Module
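A short sketch of the three casting methods above, using a plain torch.nn.Module (from which the mixin inherits them). Only floating point parameters and buffers are cast; integral buffers keep their dtype:

```python
import torch
from torch import nn

module = nn.Linear(3, 3)
module.register_buffer("step", torch.tensor(0))  # integral (int64) buffer

module.double()
assert module.weight.dtype == torch.float64  # floating point param cast
assert module.step.dtype == torch.int64      # integral buffer untouched

module.half()
assert module.weight.dtype == torch.float16

module.float()
assert module.weight.dtype == torch.float32
```

Each call modifies the module in-place and returns self, so they can be chained.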
- to(*args, **kwargs)[source]¶
Moves and/or casts the parameters and buffers.
This can be called as:

to(device=None, dtype=None, non_blocking=False)
to(dtype, non_blocking=False)
to(tensor, non_blocking=False)

Its signature is similar to torch.Tensor.to(), but only accepts floating point desired dtypes. In addition, this method will only cast the floating point parameters and buffers to dtype (if given). The integral parameters and buffers will be moved to device, if that is given, but with dtypes unchanged. When non_blocking is set, it tries to convert/move asynchronously with respect to the host if possible, e.g., moving CPU Tensors with pinned memory to CUDA devices. See below for examples.

Note
This method modifies the module in-place.
- Parameters
device – the desired device of the parameters and buffers in this module
dtype – the desired floating point type of the floating point parameters and buffers in this module
tensor – Tensor whose dtype and device are the desired dtype and device for all parameters and buffers in this module
- Returns
self
- Return type
Module
- Example::
>>> class ExampleModule(DeviceDtypeModuleMixin):
...     def __init__(self, weight: torch.Tensor):
...         super().__init__()
...         self.register_buffer('weight', weight)
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight
tensor([[...]])
>>> module.to(torch.double)
ExampleModule()
>>> module.weight
tensor([[...]], dtype=torch.float64)
>>> cpu = torch.device('cpu')
>>> module.to(cpu, dtype=torch.half, non_blocking=True)
ExampleModule()
>>> module.weight
tensor([[...]], dtype=torch.float16)
>>> module.to(cpu)
ExampleModule()
>>> module.weight
tensor([[...]], dtype=torch.float16)
>>> module.device
device(type='cpu')
>>> module.dtype
torch.float16
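Complementing the example above, a minimal sketch (with a plain torch.nn.Module, from which the mixin inherits to()) of the restriction that only floating point desired dtypes are accepted:

```python
import torch
from torch import nn

module = nn.Linear(2, 2)

# Requesting an integral dtype is rejected with a TypeError.
try:
    module.to(torch.int64)
    raised = False
except TypeError:
    raised = True
assert raised

# A floating point dtype works and modifies the module in-place.
module.to(torch.float64)
assert module.weight.dtype == torch.float64
```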