Answer from Slack (thanks @awaelchli):
Your code was the following:
import torch.nn as nn
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LSTMEncoder(nn.Module):
    """ simple wrapper for a bi-lstm """

    def __init__(
        self,
        emb_dim: int,
        hidden_dim: int,
        layers: int,
        bidirectional: bool,
        dropout: float,
        pack=False,
    ):
        super(LSTMEncoder, self).__init__()
        self.num_directions = 2 if bidirectional else 1
        self.lstm = nn.LSTM(
            emb_dim,
            hidden_dim // self.num_directions,
            layers,
            bidirectional=bidirectional,
            batch_first=True,
            dropout=dropout,
        )
        self.pack = pack

    def init_state(self, input):
        batch_size = input.size(0)  # retrieve dynamically for decoding
        h0 = torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size,
        )
        c0 = torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size,
        )
        return h0, c0

    def forward(self, src_embedding, srclens, srcmask, temp=1):
        h0, c0 = self.init_state(src_embedding)

        if self.pack:
            inputs = pack_padded_sequence(src_embedding, srclens, batch_first=True)
        else:
            inputs = src_embedding

        outputs, (h_final, c_final) = self.lstm(inputs, (h0, c0))

        if self.pack:
            outputs, _ = pad_packed_sequence(outputs, batch_first=True)

        # outputs: batch, seq_len, num_directions * hidden_size
        # h_n: num_layers * num_directions, batch, hidden_size
        # c_n: num_layers * num_directions, batch, hidden_size
        return outputs, (h_final, c_final)
Lightning automatically moves the model's parameter weights and the batches coming out of your DataLoaders to the appropriate device, but it cannot move tensors that you create yourself inside forward() (here, the h0 and c0 built in init_state()). The fix is easy:
h0 = torch.zeros(..., device=input.device)
c0 = torch.zeros(..., device=input.device)
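Applied to the init_state() from your code above, the corrected method would look roughly like this (same shapes as before, only the device argument added):

    def init_state(self, input):
        batch_size = input.size(0)  # retrieve dynamically for decoding
        # create the initial hidden/cell states on the same device as the input,
        # so they follow the model and the batch onto the GPU
        h0 = torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size,
            device=input.device,
        )
        c0 = torch.zeros(
            self.lstm.num_layers * self.num_directions,
            batch_size,
            self.lstm.hidden_size,
            device=input.device,
        )
        return h0, c0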
By specifying the appropriate device for the new tensors, you won’t run into this issue!
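An equivalent, slightly shorter option, if you prefer, is Tensor.new_zeros(), which creates the new tensor on the same device and with the same dtype as the tensor it is called on:

    shape = (self.lstm.num_layers * self.num_directions, batch_size, self.lstm.hidden_size)
    h0 = input.new_zeros(shape)  # inherits device and dtype from `input`
    c0 = input.new_zeros(shape)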