Hello all,
I’m trying to train a neural network on a tabular Parquet dataset that cannot fit into memory. As a solution, I’ve been using PyArrow to load one row at a time, leaving the DataLoader to handle batching. I decided to wrap this in a PyTorch IterableDataset that implements the __len__ method, so that Lightning’s Trainer can maintain the notion of an epoch.
I’ll post the code below, but broadly speaking, the structure is: I create a generator for each Parquet file using PyArrow’s iter_batches method, chain these together using itertools.chain, and return an iterator over the resulting chain from the __iter__ method of a PyTorch IterableDataset.
In addition, as described in PyTorch’s documentation, I use torch.utils.data.get_worker_info to assign each worker a subset of these Parquet-file generators, to avoid serving redundant data. I implement __len__ by iterating over the Parquet files in my dataset and adding up the number of rows in each.
Here’s the code:
import io
import boto3
import pyarrow.parquet as pq
from pyarrow.fs import S3FileSystem
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from itertools import chain
from pytorch_lightning import Trainer, LightningDataModule

# `bucket` (S3 bucket name) and `bs` (per-file row batch size) are defined elsewhere in my script
class MyDataset(IterableDataset):
    def __init__(self, files):
        self.all_gens = [single_parquet_gen(f) for f in files]
        self.length = get_parquet_dataset_length(files)

    def __len__(self):
        return self.length

    def __iter__(self):
        worker_info = get_worker_info()
        if worker_info is not None:
            num_workers = worker_info.num_workers
            worker_id = worker_info.id
            # Iterable datasets must split up their iterables across workers to avoid repeating data.
            # Assign each worker roughly len(self.all_gens) / num_workers parquet files.
            pqs_this_worker = assign_subset_to_worker(self.all_gens, num_workers, worker_id)
            chain_this_worker = chain(*pqs_this_worker)
            return iter(chain_this_worker)
        else:
            chained = chain(*self.all_gens)
            return iter(chained)
def single_parquet_gen(key):
    # Stream one parquet file from S3, yielding one formatted row at a time
    s3 = boto3.client('s3')
    obj = s3.get_object(Bucket=bucket, Key=key)
    parquet_file = pq.ParquetFile(io.BytesIO(obj['Body'].read()))
    gen = parquet_file.iter_batches(batch_size=bs)
    for batch in gen:
        yield format_data(batch.to_pylist()[0])

def format_data(data):
    # <This part formats the tabular data to the shape/type my Pytorch network expects>
    ...
def get_parquet_dataset_length(files):
    # Total row count across all parquet files, read from the parquet metadata
    dataset = pq.ParquetDataset(files, filesystem=S3FileSystem())
    nrow = 0
    for fragment in dataset.fragments:
        nrow += fragment.metadata.num_rows
    return nrow
def assign_subset_to_worker(files, num_splits, split_id):
    n = len(files)
    size = n // num_splits  # Initialize with equal sizes
    remainder = n % num_splits
    # Distribute the remainder items among workers 0:remainder
    start_idx = split_id * size + min(remainder, split_id)
    end_idx = start_idx + size + (1 if split_id < remainder else 0)
    return files[start_idx:end_idx]
class MyDataModule(LightningDataModule):
    def __init__(self, directory, bs, num_workers):
        super().__init__()
        self.bs = bs
        self.num_workers = num_workers
        self.train_files = get_s3_uris(bucket, directory + '/train')
        self.val_files = get_s3_uris(bucket, directory + '/val')
        self.test_files = get_s3_uris(bucket, directory + '/test')

    def setup(self, stage=None):
        self.train_data = MyDataset(self.train_files)
        self.val_data = MyDataset(self.val_files)
        self.test_data = MyDataset(self.test_files)

    def train_dataloader(self):
        return self.get_dataloader(self.train_data)

    def val_dataloader(self):
        return self.get_dataloader(self.val_data)

    def get_dataloader(self, dataset):
        return DataLoader(dataset, batch_size=self.bs, num_workers=self.num_workers,
                          drop_last=False, persistent_workers=self.num_workers > 0,
                          pin_memory=False, sampler=None)
def get_s3_uris(bucket, prefix):
    # List all data object keys under a prefix, skipping marker files such as _SUCCESS
    s3 = boto3.client('s3')
    paginator = s3.get_paginator('list_objects_v2')
    pages = paginator.paginate(Bucket=bucket, Prefix=prefix)
    uris = []
    for page in pages:
        for obj in page['Contents']:
            if obj['Key'].split('/')[-1][0] == '_':
                continue
            uris.append(obj['Key'])
    return uris
t = Trainer(check_val_every_n_epoch=1, max_epochs=2, reload_dataloaders_every_n_epochs=1)
t.fit(<some_model>, datamodule=datamodule)
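In case the splitting helper looks odd: assign_subset_to_worker just hands each worker a contiguous slice of the list, with the first `remainder` workers taking one extra item. A quick check of that logic with some made-up file names:

# Quick check of the splitting logic, using made-up file names
files = ['f0', 'f1', 'f2', 'f3', 'f4']
print(assign_subset_to_worker(files, num_splits=2, split_id=0))  # ['f0', 'f1', 'f2']
print(assign_subset_to_worker(files, num_splits=2, split_id=1))  # ['f3', 'f4']

So every file (and therefore every row) ends up with exactly one worker.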
Using the datamodule in isolation, I’ve confirmed that the dataloader runs for the expected number of steps (a rough version of that check is included after the list below). However, when using the Trainer, I’ve experienced various issues with this setup:
- First, my Trainer only made it through one train epoch; the next would have 0 steps. This was remedied by setting reload_dataloaders_every_n_epochs=1.
- However, each train epoch after the first does not run for ceiling(len(dataset) / batch_size) steps, as expected.
- The Trainer still only runs a single validation epoch. Afterwards, each train epoch goes right into the next.
- On top of this, the validation epoch skips 2 steps, plus another 2 per worker (i.e. if len(dataset) / batch_size = 40 and I use num_workers=2 in my dataloader, the validation epoch only goes for 34 steps according to the progress bar).
- Moreover, validation does not run at all if I set num_workers=0.
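For reference, here is roughly the standalone check I mentioned above (the directory and batch size of 32 are just illustrative); it counts the batches the dataloader yields without involving the Trainer:

import math

# Standalone check, outside of the Trainer: count the batches the
# train dataloader actually yields (directory and batch size are illustrative)
dm = MyDataModule(directory, bs=32, num_workers=0)
dm.setup()

n_steps = sum(1 for _ in dm.train_dataloader())
print(n_steps, math.ceil(len(dm.train_data) / 32))  # these match when run like this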
Most of my testing was performed with num_workers=1 for simplicity, but all of the above issues still occur. Also, I’ve run this code using a standard map-style Dataset that loads a subset of the data into memory, and everything worked fine, which has made me stop investigating format_data, get_s3_uris, or the parquet files themselves as culprits.
Can anyone lend a hand? Am I doing something wrong, forgetting some setting, or is there an issue in how Lightning treats finite-length IterableDatasets?
(I had a thought that maybe the shortened val epochs have to do with sanity checking. These checks run two steps, but I’m not sure whether they run once per process as well as once at the start of training. I don’t believe this would explain the train dataloader behavior, though, unless sanity checking has some hidden interaction with train dataloaders. I suspect some overall inconsistency in resetting iterators at the appropriate time.)
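One experiment that might isolate the sanity-check theory (assuming I’ve read the docs right that num_sanity_val_steps controls those two steps) would be to disable it and see whether the validation step count changes, something like:

# Test of the sanity-check theory: turn off the pre-training sanity check
# (which defaults to 2 validation steps) and see if the val epoch length changes
t = Trainer(check_val_every_n_epoch=1, max_epochs=2,
            reload_dataloaders_every_n_epochs=1, num_sanity_val_steps=0)
t.fit(<some_model>, datamodule=datamodule)

I haven’t confirmed yet whether that accounts for the full shortfall, since it wouldn’t obviously explain the per-worker part.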