Hi there!
In my collate_fn, I am trying to return my own custom batch class instead of a tuple of tensors. The problem is that the Trainer does not move the tensors inside CustomBatch to the correct device:
```python
import torch


class CustomBatch:
    def __init__(self,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 pad_x: torch.Tensor,
                 pad_y: torch.Tensor):
        self.x = x
        self.y = y
        self.pad_x = pad_x
        self.pad_y = pad_y


def collate_fn(batch):
    # Concatenate all sequences in the batch along the first dimension.
    X = torch.cat([seq.x for seq in batch])
    Y = torch.cat([seq.y for seq in batch])
    # For every row, record the index of the sequence it came from.
    pad_x = []
    pad_y = []
    for i, seq in enumerate(batch):
        pad_x.extend([i] * seq.x.size(0))
        pad_y.extend([i] * seq.y.size(0))
    pad_x = torch.tensor(pad_x, dtype=torch.long)
    pad_y = torch.tensor(pad_y, dtype=torch.long)
    return CustomBatch(X, Y, pad_x, pad_y)
```
If I instead return the tensors directly (as a tuple), it works fine:

```python
return X, Y, pad_x, pad_y
```
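My guess (an assumption on my part, not checked against the Lightning source for my version) is that the tuple works because the device transfer recurses into standard containers such as lists, tuples and dicts and calls `.to(device)` on every tensor it finds, while an unknown object like `CustomBatch` is passed through untouched. A simplified sketch of that behaviour, using a hypothetical helper `move_to_device` that is not Lightning's actual function:

```python
import torch


def move_to_device(batch, device):
    """Hypothetical, simplified recursive batch transfer (not Lightning's code)."""
    if isinstance(batch, torch.Tensor):
        return batch.to(device)
    if isinstance(batch, list):
        return [move_to_device(b, device) for b in batch]
    if isinstance(batch, tuple):
        # Rebuilt as a plain tuple; namedtuples are not preserved in this sketch.
        return tuple(move_to_device(b, device) for b in batch)
    if isinstance(batch, dict):
        return {k: move_to_device(v, device) for k, v in batch.items()}
    # Anything else (e.g. a plain custom class) is returned unchanged,
    # so the tensors stored on it stay on their original device.
    return batch
```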
Do I need to add a special method to CustomBatch so that the Trainer knows how to handle its contents?
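One thing I am considering, assuming the Trainer moves any batch object that implements a `to(device)` method by calling that method (which is how I read the `transfer_batch_to_device` documentation, but I have not confirmed it for my version): give `CustomBatch` a `to` method that moves each tensor and returns the batch. A minimal sketch:

```python
import torch


class CustomBatch:
    def __init__(self,
                 x: torch.Tensor,
                 y: torch.Tensor,
                 pad_x: torch.Tensor,
                 pad_y: torch.Tensor):
        self.x = x
        self.y = y
        self.pad_x = pad_x
        self.pad_y = pad_y

    def to(self, device):
        # Move every tensor attribute to the target device.
        self.x = self.x.to(device)
        self.y = self.y.to(device)
        self.pad_x = self.pad_x.to(device)
        self.pad_y = self.pad_y.to(device)
        # Return self so callers can write `batch = batch.to(device)`.
        return self
```

The method mutates the batch in place and returns it, which mirrors how `Tensor.to` is usually called (`batch = batch.to(device)`).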
Best,
Arturo