Training when data is stored in batches

Hello everyone, I suspect this is an obvious question, but after reading the docs and searching the forums I could not find a solution.

My situation: I have 40 million images that I want to use to learn a latent space with a self-supervised network. To reduce the load on the filesystem, I split them into 700 batches (files) of roughly 50,000 images each; the number of images per batch varies.
I wrote my dataset so that each element it returns is a whole batch with shape (B, C, H, W).
The DataLoader uses batch_size=1.

Originally, I planned to use each stored batch as a training batch, but a forward pass over 50,000 images needs more RAM than I have. Is there a way to split each loaded batch into smaller mini-batches during training? (Kind of the opposite of gradient accumulation.)
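Splitting a loaded batch amounts to slicing it along its first axis. A minimal sketch, assuming the stored batch is already in memory as a NumPy array of shape (B, C, H, W); `iter_minibatches` is a hypothetical helper, not part of PyTorch or Lightning:

```python
import numpy as np


def iter_minibatches(chunk, batch_size, rng=None):
    """Yield consecutive slices of `chunk` along axis 0.

    If `rng` (a NumPy Generator) is given, rows are visited in random order,
    i.e. shuffling happens inside this chunk only. The last slice may be smaller.
    """
    n = chunk.shape[0]
    order = np.arange(n) if rng is None else rng.permutation(n)
    for start in range(0, n, batch_size):
        # Fancy indexing copies only this mini-batch, not the whole chunk.
        yield chunk[order[start:start + batch_size]]


# Example: a fake stored batch of 10 single-channel 66x66 images.
chunk = np.zeros((10, 1, 66, 66), dtype=np.float32)
sizes = [mb.shape[0] for mb in iter_minibatches(chunk, batch_size=4)]
print(sizes)  # [4, 4, 2]
```

Only one mini-batch at a time goes through the forward pass; the full stored batch still has to fit in host memory.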

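import numpy as np
import torch
from torch.utils.data import Dataset


class StampsDataset(Dataset):
    """Each item is one whole stored batch of shape (B, C, H, W)."""

    def __init__(self, filenames, img_transform=None):
        self.img_transform = img_transform
        self.filenames = filenames

    def __getitem__(self, idx):
        # Assumption: each file is a .npy array of shape (B, C, H, W)
        images = torch.from_numpy(np.load(self.filenames[idx]))
        if self.img_transform is not None:
            images = self.img_transform(images)

        return images

    def __len__(self):
        return len(self.filenames)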

I could do this by writing the training loop manually, but I would still like to use the Lightning Trainer API.
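One way to keep the Trainer untouched is to move the splitting into the data pipeline: an iterable dataset that loads one stored batch at a time and yields mini-batches. A minimal sketch, assuming a hypothetical `load_chunk(path)` callable that returns a (B, C, H, W) array; it uses only NumPy so the logic can be checked on its own. In PyTorch the same class would subclass `torch.utils.data.IterableDataset` and be wrapped as `DataLoader(dataset, batch_size=None)`, which disables automatic batching so each yielded mini-batch is only converted to a tensor and passed on as-is.

```python
import numpy as np


class ChunkStream:
    """Yield fixed-size mini-batches from a sequence of stored-batch files.

    Follows the torch.utils.data.IterableDataset protocol (only __iter__).
    `load_chunk` is a hypothetical callable: path -> array of shape (B, C, H, W).
    """

    def __init__(self, paths, load_chunk, batch_size, seed=None):
        self.paths = list(paths)
        self.load_chunk = load_chunk
        self.batch_size = batch_size
        self.seed = seed  # None -> a fresh random order on every pass

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        # Shuffle the order of files, then the rows inside each file.
        for i in rng.permutation(len(self.paths)):
            # The previous chunk is released once `chunk` is rebound.
            chunk = self.load_chunk(self.paths[i])
            order = rng.permutation(chunk.shape[0])
            for start in range(0, len(order), self.batch_size):
                yield chunk[order[start:start + self.batch_size]]


# Illustration with a fake loader: each "path" maps to a chunk size.
fake_sizes = {"a": 5, "b": 3}
stream = ChunkStream(
    ["a", "b"],
    lambda p: np.zeros((fake_sizes[p], 1, 66, 66), dtype=np.float32),
    batch_size=2,
)
print(sum(mb.shape[0] for mb in stream))  # 8
```

A mini-batch never mixes two files, so shuffling is only approximate. With `num_workers > 0` each worker gets its own copy of the dataset and would replay every file, so `paths` would have to be split per worker (for example using `torch.utils.data.get_worker_info()`).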

Thank you in advance.

@ClarkGuilty I think the easiest option is to save each sample as its own file. Then it is trivial to index into the folder of all files and let the DataLoader batch your samples together, and you can change the batch size depending on how much memory you have available. I think any other workaround is too inflexible and quite hard to maintain; I'm not sure it would be worth it.
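With one file per sample, `__getitem__` reduces to a single read and the DataLoader's `batch_size` controls memory directly. A minimal sketch, assuming each image was saved as its own `.npy` file in one directory; `PerSampleFiles` is a hypothetical name, and in PyTorch it would subclass `torch.utils.data.Dataset` and be used as `DataLoader(PerSampleFiles(directory), batch_size=..., shuffle=True)`:

```python
import os

import numpy as np


class PerSampleFiles:
    """Map-style dataset over a directory holding one .npy file per image."""

    def __init__(self, directory):
        self.directory = directory
        # Sort so that index i always refers to the same file.
        self.names = sorted(n for n in os.listdir(directory) if n.endswith(".npy"))

    def __len__(self):
        return len(self.names)

    def __getitem__(self, idx):
        return np.load(os.path.join(self.directory, self.names[idx]))
```

The trade-off is the file count: one file per image means tens of millions of small files, and listing the directory alone can be slow on a network filesystem.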

Thank you for your answer.

Sadly, my files are small but numerous (37 million), so they put too much strain on the filesystem (and they are on a remote server). I ended up writing the dataset so that it returns only one image at a time and only loads a new batch file when the requested image lives in a different one. It takes 0.4 s to load one image, and 2.4 s to load a new batch plus a new image.
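This approach needs two pieces: a mapping from a global index to (file, position within file), and a one-file cache that reloads only when the file changes. The dataset in this thread uses a pandas MultiIndex for the mapping. An equivalent sketch with cumulative file sizes and `np.searchsorted`, assuming the per-file image counts are known up front; `ChunkCache` and its `load_chunk` callable are hypothetical names:

```python
import numpy as np


class ChunkCache:
    """Serve single images, keeping only the most recently used file in memory."""

    def __init__(self, chunk_sizes, load_chunk):
        self.ends = np.cumsum(chunk_sizes)  # exclusive end index of each file
        self.load_chunk = load_chunk        # hypothetical: chunk_id -> array
        self.current_id = None
        self.current = None
        self.loads = 0                      # number of file reads so far

    def locate(self, idx):
        """Map a global index to (chunk_id, local_idx)."""
        chunk_id = int(np.searchsorted(self.ends, idx, side="right"))
        start = 0 if chunk_id == 0 else int(self.ends[chunk_id - 1])
        return chunk_id, idx - start

    def __len__(self):
        return int(self.ends[-1])

    def __getitem__(self, idx):
        chunk_id, local = self.locate(idx)
        if chunk_id != self.current_id:     # reload only on a file change
            self.current = self.load_chunk(chunk_id)
            self.current_id = chunk_id
            self.loads += 1
        return self.current[local]


sizes = [3, 2, 4]
cache = ChunkCache(sizes, lambda cid: np.arange(sizes[cid]) + 10 * cid)
print(cache.locate(4))  # (1, 1)
_ = [cache[i] for i in range(len(cache))]
print(cache.loads)      # 3: sequential access reads each file once
```

The cache only pays off when indices arrive file by file. With `shuffle=True` in the DataLoader, nearly every request would hit a different file and trigger a reload, at roughly 2.4 s each by the timings above.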

I include the code for completeness.

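import os

import numpy as np
import torch
from torch.utils.data import Dataset


class MoreEfficientStampsDataset(Dataset):
    def __init__(self, filenames, df, data_path, img_transform=None, shuffle=False):
        self.img_transform = img_transform
        self.filenames = filenames
        self.shuffle = shuffle
        self.data_path = data_path
        # df is indexed by (filename, position of the image within that file)
        self.df = df.loc[filenames]
        self._index = self.df.index
        # Cache for the single file currently held in memory
        self.loaded_file = None
        self._loaded_data = None

    def __getitem__(self, idx):
        filename, local_idx = self._index[idx]
        if filename != self.loaded_file:
            # Only read from disk when the image lives in a different file
            self._load_file(filename)
        image = self._loaded_data[local_idx].reshape(1, 66, 66).float()
        if self.img_transform is not None:
            image = self.img_transform(image)
        return image, idx

    def __len__(self):
        return len(self.df)

    def _load_file(self, filename):
        # Assumption: each file is a .npy array stored under data_path
        loaded_data = np.load(os.path.join(self.data_path, filename))
        if self.shuffle:
            # Assumption: shuffle images within the file (along axis 0).
            # After this, idx no longer matches the df row of the returned image.
            np.random.shuffle(loaded_data)
        self._loaded_data = torch.from_numpy(loaded_data)
        self.loaded_file = filename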
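Because `MoreEfficientStampsDataset` keeps only one file in memory, it stays fast only if the DataLoader requests indices file by file (e.g. `shuffle=False`). If epoch-level shuffling is wanted, one option is an index order that shuffles the order of files and the images within each file but never interleaves files. A minimal sketch, assuming the dataset's index lists each file's images contiguously and the per-file counts are known in that order; `chunk_shuffled_indices` is a hypothetical helper that could back the `__iter__` of a custom `torch.utils.data.Sampler` passed as `DataLoader(..., sampler=...)`:

```python
import random


def chunk_shuffled_indices(chunk_sizes, seed=None):
    """Yield every global index once: file order and within-file order
    are shuffled, but all indices of one file stay contiguous."""
    rng = random.Random(seed)
    starts, total = [], 0
    for size in chunk_sizes:
        starts.append(total)
        total += size
    order = list(range(len(chunk_sizes)))
    rng.shuffle(order)                    # shuffle the file order
    for c in order:
        block = list(range(starts[c], starts[c] + chunk_sizes[c]))
        rng.shuffle(block)                # shuffle images within the file
        yield from block


indices = list(chunk_shuffled_indices([3, 2, 4], seed=0))
```

Within-file shuffling could instead be left to the dataset's own `shuffle` flag. With several DataLoader workers, each worker holds its own cache, so a file may be read once per worker.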