Hi
I have a dataset of type tf.data.Dataset, which is iterable but does not support random access to individual elements. I read the documentation here (torch.utils.data — PyTorch 1.12 documentation), but it was not clear to me how to write my own DataLoader for an iterable dataset, and I would appreciate some examples. In particular, my questions below refer to the tutorial example pasted at the end:
- With PyTorch Lightning, do I need to set worker_info myself, or is it set automatically?
- In case I do need to, could you tell me how __iter__ should be written? In this example it returns iter(range(iter_start, iter_end)), and I am not sure how this should be done for an iterable dataset like mine (I sketched my guess after the tutorial code below).
- According to the tutorial, the dataset needs to be split across multiple workers; how can I know how many workers are available in each case (TPU / multiple GPUs / …)? (Again, see the second sketch below for my current guess.)
thanks
import math

import torch


class MyIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, start, end):
        super().__init__()
        assert end > start, "this example code only works with end > start"
        self.start = start
        self.end = end

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:  # single-process data loading, return the full iterator
            iter_start = self.start
            iter_end = self.end
        else:  # in a worker process
            # split workload evenly across workers
            per_worker = int(math.ceil((self.end - self.start) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            iter_start = self.start + worker_id * per_worker
            iter_end = min(iter_start + per_worker, self.end)
        return iter(range(iter_start, iter_end))
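
For context, here is a minimal sketch of what I imagine __iter__ could look like when the underlying source is a tf.data.Dataset (so there is no random access). The make_tf_dataset factory is hypothetical, and the shard-by-worker logic is only my assumption based on the tutorial:

import torch
import tensorflow as tf


class TFIterableDataset(torch.utils.data.IterableDataset):
    def __init__(self, make_tf_dataset):
        # make_tf_dataset is a hypothetical factory that builds the
        # tf.data pipeline; building it lazily (per worker) avoids
        # sharing TF state across worker processes.
        super().__init__()
        self.make_tf_dataset = make_tf_dataset

    def __iter__(self):
        # get_worker_info() is populated automatically by the DataLoader
        # inside each worker process; it is None in the main process.
        worker_info = torch.utils.data.get_worker_info()
        ds = self.make_tf_dataset()
        if worker_info is not None:
            # Assumption: each worker keeps every num_workers-th element
            # so the workers do not yield duplicates.
            ds = ds.shard(num_shards=worker_info.num_workers,
                          index=worker_info.id)
        for example in ds.as_numpy_iterator():
            yield example

If that is right, then it would also answer my first question: worker_info does not need to be set manually, since the DataLoader populates it per worker.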
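
For the TPU / multi-GPU question, my current guess is that the total number of data consumers is world_size * num_workers, i.e. the distributed rank and the worker id have to be combined. A sketch, assuming the process group is already initialized (e.g. by PyTorch Lightning):

import torch
import torch.distributed as dist


def global_shard():
    # Distributed replicas (one per GPU/TPU core), if initialized.
    if dist.is_available() and dist.is_initialized():
        rank, world_size = dist.get_rank(), dist.get_world_size()
    else:
        rank, world_size = 0, 1

    # DataLoader workers inside this replica, if any.
    worker_info = torch.utils.data.get_worker_info()
    if worker_info is not None:
        num_workers, worker_id = worker_info.num_workers, worker_info.id
    else:
        num_workers, worker_id = 1, 0

    # Every (replica, worker) pair gets exactly one shard.
    num_shards = world_size * num_workers
    shard_id = rank * num_workers + worker_id
    return num_shards, shard_id

Is that the right way to think about it, or does Lightning handle the distributed split for iterable datasets on its own?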