How to change the way dataloader handles data?

I have the following code:

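```python
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Dataset


class TrainDataset(Dataset):
    def __getitem__(self, index):
        # ... (src_ids, src_mask, target_ids, label are built here)
        out = {
            'source_ids': src_ids,
            'source_mask': src_mask,
            'target_ids': target_ids,
            'label': label,
        }
        # ... (out_list collects n dicts like `out`: [{a1}, {a2}, ..., {an}])
        return out_list


class DataModule(pl.LightningDataModule):
    def prepare_data(self):
        self.train = TrainDataset(args)

    def train_dataloader(self):
        train_loader = DataLoader(self.train, batch_size=2)  # other arguments elided
        return train_loader
```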

`TrainDataset` returns one sample per index, and each sample is a list of dicts: `[{a1}, {a2}, ..., {an}]`. When I set `batch_size` to 2, the DataLoader groups the data by position instead, like this: `[[{a1},{b1}], ..., [{an},{bn}]]`. What I want is for each sample to stay intact: `[[{a1},{a2}...{an}], [{b1},{b2}...{bn}]]`.
I hope I made my question clear :thinking:
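
The behaviour described above can be reproduced with a small sketch. `ToyDataset` is a hypothetical stand-in for `TrainDataset` (only the `source_ids` and `label` keys, with made-up values), and the batch is built by the DataLoader's default `collate_fn`:

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for TrainDataset: each index yields n=3 dicts."""

    def __init__(self, num_samples=4, n=3):
        self.num_samples = num_samples
        self.n = n

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # One sample = a list of n dicts, like [{a1}, {a2}, {a3}]
        return [
            {
                "source_ids": torch.tensor([index, j]),
                "label": torch.tensor(index * 10 + j),
            }
            for j in range(self.n)
        ]


loader = DataLoader(ToyDataset(), batch_size=2)  # default collate_fn
batch = next(iter(loader))

# The default collate transposes the list: position j of every sample is
# grouped together, and each group of dicts is merged into one dict of
# stacked tensors.
print(len(batch))                        # 3 (one entry per position j)
print(batch[0]["label"].tolist())        # [0, 10] -> a1 and b1 together
print(batch[0]["source_ids"].shape)      # torch.Size([2, 2])
```

So the batch has one entry per position in the sample list, and each entry is a single dict whose tensors carry a leading batch dimension of 2.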


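The default `collate_fn` transposes a batch of lists: it groups the items at each position across samples (a1 with b1, a2 with b2, ...), which is why you get that layout. To change this, you can pass your own `collate_fn` to the `DataLoader`. It receives the list of samples returned by your dataset and defines how they are combined into a batch: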

```python
from torch.utils.data import DataLoader


def collate_fn(samples):
    # samples is the list of samples returned by your dataset,
    # to be assembled into a batch. Returning it unchanged gives:
    # [[{a1},{a2}...{an}], [{b1},{b2}...{bn}]]
    return samples


dataloader = DataLoader(..., collate_fn=collate_fn)
```
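
A minimal end-to-end sketch of this suggestion, again using a hypothetical `ToyDataset` in place of `TrainDataset` (one `label` tensor per dict, values invented for illustration):

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):
    """Hypothetical stand-in for TrainDataset: each index yields n=3 dicts."""

    def __init__(self, num_samples=4, n=3):
        self.num_samples = num_samples
        self.n = n

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        # One sample = a list of n dicts, like [{a1}, {a2}, {a3}]
        return [{"label": torch.tensor(index * 10 + j)} for j in range(self.n)]


def collate_fn(samples):
    # samples == [sample_0, sample_1, ...], one list of n dicts per index.
    # Returning it unchanged keeps each sample's dicts together.
    return samples


loader = DataLoader(ToyDataset(), batch_size=2, collate_fn=collate_fn)
batch = next(iter(loader))

# batch == [[a1, a2, a3], [b1, b2, b3]]
print(len(batch), len(batch[0]))                 # 2 3
print([d["label"].item() for d in batch[0]])     # [0, 1, 2]
print([d["label"].item() for d in batch[1]])     # [10, 11, 12]
```

Because the dicts are returned unchanged, their tensors are not stacked, so the training step has to iterate over (or stack) them itself. In the question's `DataModule`, the same function would be passed as `DataLoader(self.train, batch_size=2, collate_fn=collate_fn)`.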

Here are the PyTorch docs for this.

Hope this helps