Another tip I found browsing the GitHub issues is to prevent the DataLoader worker processes from being respawned at the start of every epoch; if your epochs are short, this makes a big difference.
```python
import torch.utils.data

# Originally proposed by PetrochukM in https://github.com/pytorch/pytorch/issues/15849#issuecomment-518126031
# Modified by monoelh in https://github.com/PyTorchLightning/pytorch-lightning/issues/2875#issuecomment-673355304
class _RepeatSampler(object):
    """Sampler that repeats forever.

    Args:
        sampler (Sampler)
    """

    def __init__(self, sampler):
        self.sampler = sampler

    def __iter__(self):
        while True:
            yield from iter(self.sampler)


class ContinuousDataLoader(torch.utils.data.DataLoader):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Swap in the repeating sampler, temporarily bypassing DataLoader's
        # attribute guard that forbids changing batch_sampler after init.
        self._DataLoader__initialized = False
        self.batch_sampler = _RepeatSampler(self.batch_sampler)
        self._DataLoader__initialized = True
        # Create the worker iterator once and keep it alive across epochs.
        self.iterator = super().__iter__()

    def __len__(self):
        return len(self.batch_sampler.sampler)

    def __iter__(self):
        for i in range(len(self)):
            yield next(self.iterator)
```
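As a rough illustration of the drop-in usage, here is a minimal sketch; the dataset, `batch_size`, and `num_workers` values are placeholders of my own, not from the linked issues. You construct `ContinuousDataLoader` exactly like a regular `DataLoader`, and the workers started for the first epoch keep serving batches in later epochs.

```python
# Minimal usage sketch; dataset and hyperparameters are illustrative placeholders.
import torch
from torch.utils.data import TensorDataset

dataset = TensorDataset(torch.randn(512, 10), torch.randint(0, 2, (512,)))

# Same constructor arguments as torch.utils.data.DataLoader.
loader = ContinuousDataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

for epoch in range(5):
    for x, y in loader:  # workers are reused across epochs instead of respawned
        pass  # training step goes here
```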
See [DataLoader with option to re-use worker processes · Issue #15849 · pytorch/pytorch](https://github.com/pytorch/pytorch/issues/15849) and [Significant Amount of Idle Time Not Part of Training or Validation Steps · Issue #2875 · Lightning-AI/lightning](https://github.com/PyTorchLightning/pytorch-lightning/issues/2875) for more details.