I have the following code:
class TrainDataset(Dataset):
    ......
    def __getitem__(self, index):
        ......
        out = {
            'source_ids': src_ids,
            'source_mask': src_mask,
            'target_ids': target_ids,
            'label': label
        }
        out_list.append(out)
        return out_list

class DataModule(pl.LightningDataModule):
    def prepare_data(self):
        self.train = TrainDataset(args)

    def train_dataloader(self):
        train_loader = DataLoader(self.train,
                                  batch_size=self.batch_size,
                                  shuffle=True,
                                  pin_memory=True,
                                  num_workers=4)
        return train_loader
For each index, TrainDataset returns one item that is a list of dicts, like [{a1},{a2}...{an}]. When I set batch_size to 2, the DataLoader collates my data like this: [[{a1},{b1}]...[{an},{bn}]], but what I expect is for it to group the data like this: [[{a1},{a2}...{an}],[{b1},{b2}...{bn}]].
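To make the expected layout concrete, here is a minimal sketch of a custom collate function that keeps each item's list of dicts together instead of pairing elements across items. The name keep_samples_together and the stand-in items are hypothetical and not part of my real code. It runs on plain Python lists only, so it shows the grouping but not the tensor batching.

```python
# Hypothetical helper, not part of the original code: a collate function
# that keeps each item's list of dicts intact instead of pairing the
# i-th dict of every item together.
def keep_samples_together(batch):
    # `batch` is a list with one entry per item drawn for this batch;
    # each entry is the list of dicts returned by __getitem__.
    return [list(sample) for sample in batch]


# Stand-ins for two items, each a list of n = 3 dicts.
sample_a = [{"label": f"a{i}"} for i in range(1, 4)]
sample_b = [{"label": f"b{i}"} for i in range(1, 4)]

grouped = keep_samples_together([sample_a, sample_b])
print([[d["label"] for d in sample] for sample in grouped])
# [['a1', 'a2', 'a3'], ['b1', 'b2', 'b3']]
```

If this is the right direction, I assume it would be passed to the loader through the collate_fn argument, e.g. DataLoader(self.train, batch_size=self.batch_size, collate_fn=keep_samples_together). Note that this sketch does not stack the values into batched tensors, so that step would still have to happen somewhere else.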
I hope I have made my question clear.