Facing various issues with the validation loop when using an IterableDataset that implements __len__

Hello all,

I’m trying to train a neural network on a tabular Parquet dataset that doesn’t fit into memory. As a workaround, I’ve been using PyArrow to load one row at a time and leaving the DataLoader to handle batching. I wrapped this in a PyTorch IterableDataset that implements the __len__ method, so that Lightning’s Trainer can maintain the notion of an epoch.

I’ll post the code below, but broadly speaking, my structure is to create a generator for each Parquet file using PyArrow’s iter_batches method. I chain these together with itertools.chain and return an iterator over the resulting chain from the __iter__ method of a PyTorch IterableDataset.

In addition, as described in PyTorch’s documentation, I use torch.utils.data.get_worker_info to assign each worker a subset of these Parquet file generators, so that no rows are duplicated across workers. I implement __len__ by iterating over the Parquet files in my dataset and adding up the number of rows in each.
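The contiguous per-worker split can be checked on its own. The sketch below restates the arithmetic of `assign_subset_to_worker` from the code further down so it runs standalone, uses plain integers as stand-ins for the per-file generators, and asserts that every file lands on exactly one worker. `check_split` is a hypothetical helper for illustration only. One edge case it makes visible: with more workers than files, the extra workers receive an empty list and therefore yield nothing.

```python
def assign_subset_to_worker(files, num_splits, split_id):
    """Contiguous split of `files` into `num_splits` chunks; returns chunk `split_id`."""
    n = len(files)
    size = n // num_splits  # every worker gets at least this many items
    remainder = n % num_splits
    # Workers 0..remainder-1 each take one extra item
    start_idx = split_id * size + min(remainder, split_id)
    end_idx = start_idx + size + (1 if split_id < remainder else 0)
    return files[start_idx:end_idx]


def check_split(num_files, num_workers):
    """Hypothetical checker: return per-worker chunks, asserting full, non-overlapping coverage."""
    files = list(range(num_files))  # stand-ins for the per-file generators
    parts = [assign_subset_to_worker(files, num_workers, w) for w in range(num_workers)]
    flattened = [f for part in parts for f in part]
    assert flattened == files, "every file should appear exactly once, in order"
    return parts


if __name__ == "__main__":
    print(check_split(7, 3))  # [[0, 1, 2], [3, 4], [5, 6]]
    print(check_split(3, 4))  # [[0], [1], [2], []]
```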

Here’s the code:

```python
import io
from itertools import chain

import boto3
import pyarrow.parquet as pq
from pyarrow.fs import S3FileSystem
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from lightning.pytorch import LightningDataModule, Trainer


class MyDataset(IterableDataset):
    def __init__(self, bucket, keys):
        # One row generator per Parquet file, created once here
        self.all_gens = [single_parquet_gen(bucket, key) for key in keys]
        self.length = get_parquet_dataset_length(bucket, keys)

    def __len__(self):
        return self.length

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is not None:
            # Iterable dataloaders must split up the iterables to avoid repeating data
            # Assign each worker len(self.all_gens) / num_workers parquet files
            pqs_this_worker = assign_subset_to_worker(
                self.all_gens, worker_info.num_workers, worker_info.id
            )
            return iter(chain(*pqs_this_worker))
        # num_workers=0: the main process iterates over every file
        return iter(chain(*self.all_gens))


def single_parquet_gen(bucket, key, bs=1):
    # Download one Parquet file from S3 and yield its rows one at a time
    s3 = boto3.client('s3')
    obj = s3.get_object(Bucket=bucket, Key=key)
    parquet_file = pq.ParquetFile(io.BytesIO(obj['Body'].read()))
    for batch in parquet_file.iter_batches(batch_size=bs):
        for row in batch.to_pylist():  # no rows are dropped if bs > 1
            yield format_data(row)


def format_data(data):
    # <This part formats the tabular data to the shape/type my Pytorch network expects>
    ...


def get_parquet_dataset_length(bucket, keys):
    # Sum the row counts stored in each file's Parquet metadata
    s3 = S3FileSystem()
    paths = [f"{bucket}/{key}" for key in keys]
    dataset = pq.ParquetDataset(paths, filesystem=s3)
    nrow = 0
    for fragment in dataset.fragments:
        nrow += fragment.metadata.num_rows
    return nrow


def assign_subset_to_worker(files, num_splits, split_id):
    n = len(files)
    size = n // num_splits  # Initialize with equal sizes
    remainder = n % num_splits
    # Distribute the remainder items among workers 0:remainder
    start_idx = split_id * size + min(remainder, split_id)
    end_idx = start_idx + size + (1 if split_id < remainder else 0)
    return files[start_idx:end_idx]


class MyDataModule(LightningDataModule):
    def __init__(self, bucket, directory, bs, num_workers):
        super().__init__()
        self.bucket = bucket
        self.bs = bs
        self.num_workers = num_workers
        self.train_files = get_s3_uris(bucket, directory + '/train')
        self.val_files = get_s3_uris(bucket, directory + '/val')
        self.test_files = get_s3_uris(bucket, directory + '/test')

    def setup(self, stage=None):
        self.train_data = MyDataset(self.bucket, self.train_files)
        self.val_data = MyDataset(self.bucket, self.val_files)
        self.test_data = MyDataset(self.bucket, self.test_files)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data)

    def val_dataloader(self):
        return self.get_dataloader(self.val_data)

    def get_dataloader(self, dataset):
        return DataLoader(
            dataset,
            batch_size=self.bs,
            num_workers=self.num_workers,
            drop_last=False,
            persistent_workers=self.num_workers > 0,
            pin_memory=False,
            sampler=None,
        )


def get_s3_uris(bucket, prefix):
    # List object keys under a prefix, skipping names that start with '_'
    # (metadata files such as _SUCCESS) and bare "directory" keys
    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
    uris = []
    for page in pages:
        for obj in page.get('Contents', []):
            name = obj['Key'].rsplit('/', 1)[-1]
            if not name or name.startswith('_'):
                continue
            uris.append(obj['Key'])
    return uris


# model: <some_model>, a LightningModule (omitted here)
# datamodule: a MyDataModule instance
t = Trainer(check_val_every_n_epoch=1, max_epochs=2, reload_dataloaders_every_n_epochs=1)
t.fit(model, datamodule=datamodule)
```

Using the datamodule in isolation, I’ve confirmed that the dataloader runs for the expected number of steps. However, when using the Trainer, I’ve experienced various issues with this setup:

  • First, my Trainer only made it through one training epoch; the next one had 0 steps. Setting reload_dataloaders_every_n_epochs=1 fixed this.
  • However, each training epoch after the first does not run for the expected ceiling(len(dataset) / batch_size) steps.
  • The Trainer still only runs a single validation epoch. After that, each training epoch goes straight into the next.
  • On top of this, the validation epoch skips 2 steps, plus another 2 per worker (e.g. if len(dataset) / batch_size = 40 and I use num_workers=2 in my DataLoader, the progress bar shows the validation epoch running for only 40 - 2 - 2×2 = 34 steps).
  • Moreover, validation does not run at all if I set num_workers=0.
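For reference, here is the arithmetic behind the expected step count. For an IterableDataset with __len__ and drop_last=False, len(DataLoader) is ceil(len(dataset) / batch_size). With several workers, however, each worker batches its own stream of rows, so each one can finish on a partial batch and the number of batches actually produced can exceed that figure. The sketch compares the two; the function names and the per-worker row counts are hypothetical, chosen only for illustration.

```python
import math


def reported_len(num_rows, batch_size):
    """What len(DataLoader) reports for an IterableDataset with __len__ (drop_last=False)."""
    return math.ceil(num_rows / batch_size)


def batches_actually_yielded(rows_per_worker, batch_size):
    """Each worker batches its own rows, so each may end with a partial batch."""
    return sum(math.ceil(rows / batch_size) for rows in rows_per_worker)


if __name__ == "__main__":
    batch_size = 32
    rows_per_worker = [650, 630]  # hypothetical split of 1280 rows over 2 workers
    print(reported_len(sum(rows_per_worker), batch_size))         # 40
    print(batches_actually_yielded(rows_per_worker, batch_size))  # 41
```

With num_workers=1 the two numbers are always equal, so this effect alone would not account for the shortened epochs seen with a single worker.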

Most of my testing was performed with num_workers=1 for simplicity, and all of the above issues still occur. I’ve also run this code with a standard Dataset that loads a subset of the data into memory, and everything worked fine, which is why I stopped investigating format_data, get_s3_uris, or the Parquet files themselves as culprits.

Can anyone lend a hand? Am I doing something wrong, forgetting some setting, or is there an issue in how Lightning treats finite-length IterableDatasets?

(I had a thought that the shortened validation epochs might have to do with sanity checking. The sanity check runs two steps, but I’m not sure whether it runs once per process plus once at the start of training. I don’t believe this would explain the training dataloader issues, though, unless sanity checking has some hidden interaction with the training dataloaders. I suspect an overall inconsistency in when iterators get reset.)
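One way to test the iterator-reset suspicion outside Lightning and S3 is a toy stand-in that mirrors the structure of MyDataset. This sketch is plain Python (no torch), uses in-memory lists instead of Parquet files, and every name in it is hypothetical. `OnceIterable` builds its generators once in `__init__`, like the code above, so a partially consumed pass (such as a 2-step check) eats into the next pass, and a later pass yields nothing. `FreshIterable` builds new generators inside `__iter__` and restarts on every pass.

```python
from itertools import chain, islice


def rows_gen(rows):
    """Stand-in for single_parquet_gen: yields the rows of one 'file' one at a time."""
    for row in rows:
        yield row


class OnceIterable:
    """Generators are created once, in __init__ (same structure as MyDataset)."""

    def __init__(self, files):
        self.all_gens = [rows_gen(f) for f in files]

    def __iter__(self):
        return iter(chain(*self.all_gens))


class FreshIterable:
    """Generators are re-created on every call to __iter__."""

    def __init__(self, files):
        self.files = files

    def __iter__(self):
        return chain.from_iterable(rows_gen(f) for f in self.files)


if __name__ == "__main__":
    files = [[1, 2, 3], [4, 5]]
    once, fresh = OnceIterable(files), FreshIterable(files)

    print(list(islice(once, 2)), list(once))    # [1, 2] [3, 4, 5]
    print(list(islice(fresh, 2)), list(fresh))  # [1, 2] [1, 2, 3, 4, 5]
    print(list(once))                           # [] (all generators exhausted)
```

Whether this is what actually happens inside the DataLoader also depends on how each worker process gets its copy of the dataset (for example with persistent_workers=True), which this toy example does not model.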

I want to add that I ran a simple loop in pure PyTorch using the same data, functions, and class, and did not have any of these issues.