Crash if num_workers > 0

Hello everyone, I’ve got a really weird problem.
If I set num_workers > 0, I get a crash in (or before?) the initial sanity check. Task Manager shows high disk usage, and the RAM gets nearly completely filled up. I wonder if num_workers > 0 somehow causes the numpy memmaps (~7 GB each) to be loaded into memory, even though this should normally not happen. I know that my dataset class and everything else works with num_workers=0 (I did a fast_dev_run and some other tests where the RAM usage stayed low the entire time).

OS is Windows 10 Pro.
PyTorch version: 2.0.1
PyTorch Lightning version: 2.0.2
pytorch-cuda: 11.8


class MyDataset(Dataset):
    def __init__(self, parameter_settings: ps.Parameter, transform: torchvision.transforms = None):
        self.targets = torch.LongTensor(parameter_settings.get_labels())
        self.list_of_arrays = [
            np.memmap(memmap_path + ".dat", dtype='float32', mode='r', shape=hpf.read_memmap_shape(memmap_path)) for
            memmap_path in parameter_settings.get_data_path()]
        self.transform = transform

    def __getitem__(self, index):
        x = torch.from_numpy(np.stack([item[index] for item in self.list_of_arrays]))
        y = self.targets[index]
        if self.transform is not None:
            return self.transform(x), y
        return x, y

    def __len__(self):
        return len(self.targets)

Output that is printed to the console:

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name      | Type               | Params
0 | model     | ResNet             | 11.2 M
1 | train_acc | MulticlassAccuracy | 0     
2 | valid_acc | MulticlassAccuracy | 0     
3 | f1        | MulticlassF1Score  | 0     
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.733    Total estimated model params size (MB)
Traceback (most recent call last):
  File "C:\Users\x\PycharmProjects\cnn\", line 16, in <module>
  File "C:\Users\x\PycharmProjects\cnn\", line 94, in training_loop, data_module)
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\", line 520, in fit
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\", line 559, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\", line 935, in _run
    results = self._run_stage()
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\", line 978, in _run_stage
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\loops\", line 193, in run
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\loops\", line 235, in setup_data
    _check_dataloader_iterable(dl, source, trainer_fn)
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\lightning\pytorch\trainer\connectors\", line 383, in _check_dataloader_iterable
    iter(dataloader)  # type: ignore[call-overload]
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\torch\utils\data\", line 441, in __iter__
    return self._get_iterator()
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\torch\utils\data\", line 388, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "C:\Users\x\miniconda3\envs\cnn\lib\site-packages\torch\utils\data\", line 1042, in __init__
  File "C:\Users\x\miniconda3\envs\cnn\lib\multiprocessing\", line 121, in start
    self._popen = self._Popen(self)
  File "C:\Users\x\miniconda3\envs\cnn\lib\multiprocessing\", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "C:\Users\x\miniconda3\envs\cnn\lib\multiprocessing\", line 336, in _Popen
    return Popen(process_obj)
  File "C:\Users\x\miniconda3\envs\cnn\lib\multiprocessing\", line 93, in __init__
    reduction.dump(process_obj, to_child)
  File "C:\Users\x\miniconda3\envs\cnn\lib\multiprocessing\", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
OSError: [Errno 22] Invalid argument

Process finished with exit code 1


Pickler errors are a bit hard to parse, but they usually mean that there is an attribute or function in your code that isn’t pickleable. Pickling happens when the dataloader creates its worker processes: on Windows, workers are started with the spawn method, so the objects they need (including your dataset) are pickled in the main process and instantiated again in each new process. This serialization has limitations, and certain objects can’t be pickled.
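To narrow down which attribute is responsible, one option is to try pickling each attribute on its own. The following is only a rough sketch: `find_unpicklable_attributes` and the `Example` class are hypothetical names made up for illustration, not part of PyTorch or Lightning.

```python
import pickle
import threading


def find_unpicklable_attributes(obj):
    """Return the names of instance attributes that pickle refuses to serialize."""
    failing = []
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception:  # pickle raises different exception types depending on the object
            failing.append(name)
    return failing


class Example:
    """Hypothetical stand-in for a dataset with one problematic attribute."""

    def __init__(self):
        self.labels = [0, 1, 2]       # plain data, picklable
        self.lock = threading.Lock()  # locks cannot be pickled


print(find_unpicklable_attributes(Example()))
```

Note that this only catches attributes that fail outright. An attribute that pickles successfully but produces a very large payload would not show up in the list.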

It could simply be that in your dataset code, the memory-mapped arrays are not pickleable:

        self.list_of_arrays = [
            np.memmap(memmap_path + ".dat", dtype='float32', mode='r', shape=hpf.read_memmap_shape(memmap_path)) for
            memmap_path in parameter_settings.get_data_path()]

You could verify that by replacing that code with some dummy simulated data. If the pickling error goes away, then it is definitely this part that’s causing it.
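A related check, sketched here under the assumption that `np.memmap` is pickled like a regular `ndarray` (that is, with its contents copied into the pickle stream), is to compare the size of the pickled memmap with the size of the underlying data. If the two are about the same, sending a ~7 GB memmap to a worker would mean serializing ~7 GB, which would fit the high RAM and disk usage described in the question. The helper names and the small temporary file are made up for the demonstration.

```python
import os
import pickle
import tempfile

import numpy as np


def pickled_size(obj):
    """Number of bytes pickle produces for obj."""
    return len(pickle.dumps(obj))


def demo_memmap_pickle_size(n_floats=250_000):
    """Create a small memmap and return (data size in bytes, pickled size in bytes)."""
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, "demo.dat")
        np.zeros(n_floats, dtype="float32").tofile(path)
        mm = np.memmap(path, dtype="float32", mode="r", shape=(n_floats,))
        nbytes = mm.nbytes
        size = pickled_size(mm)
        del mm  # release the mapping before the temporary directory is removed
    return nbytes, size


nbytes, size = demo_memmap_pickle_size()
print(f"data: {nbytes} bytes, pickle: {size} bytes")
```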

Here is a quick hack that could work (I haven’t tested):

def __getitem__(self, index):
    # requires `self.list_of_arrays = None` in __init__
    if self.list_of_arrays is None:
        self.list_of_arrays = [ ...]  # move the code from __init__ to here
    # x, y = ...

This way, you’re memory mapping only on the first call to __getitem__, which happens after pickling and spawning the worker processes.
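A fuller sketch of this lazy-opening idea might look like the following. It is untested against the original setup and makes several assumptions: `LazyMemmapDataset`, `paths`, `shape`, and `make_demo_files` are hypothetical names (the original code gets paths and shapes from `parameter_settings` and `hpf.read_memmap_shape`, which are not reproduced here), and the class is kept free of torch so it runs with numpy alone. In real code it would subclass `torch.utils.data.Dataset` and wrap `x` with `torch.from_numpy`. The `__getstate__` method is an extra precaution that is not part of the hack above: it drops any already-opened memmaps whenever the dataset is pickled, so workers always reopen the files themselves.

```python
import os
import pickle
import tempfile

import numpy as np


class LazyMemmapDataset:
    """Map-style dataset that opens its memmaps on first access, not in __init__."""

    def __init__(self, paths, shape, labels):
        self.paths = list(paths)     # only the file paths are stored up front
        self.shape = tuple(shape)
        self.targets = np.asarray(labels)
        self.list_of_arrays = None   # opened lazily in __getitem__

    def _open(self):
        self.list_of_arrays = [
            np.memmap(p, dtype="float32", mode="r", shape=self.shape)
            for p in self.paths
        ]

    def __getitem__(self, index):
        if self.list_of_arrays is None:
            self._open()
        x = np.stack([arr[index] for arr in self.list_of_arrays])
        return x, self.targets[index]

    def __len__(self):
        return len(self.targets)

    def __getstate__(self):
        # Never ship opened memmaps to another process; send paths only.
        state = self.__dict__.copy()
        state["list_of_arrays"] = None
        return state


def make_demo_files(directory, n_files, shape):
    """Write small float32 .dat files for the demonstration."""
    paths = []
    for k in range(n_files):
        path = os.path.join(directory, f"array_{k}.dat")
        np.full(shape, k, dtype="float32").tofile(path)
        paths.append(path)
    return paths


demo_dir = tempfile.mkdtemp()
dataset = LazyMemmapDataset(make_demo_files(demo_dir, 2, (4, 3)), (4, 3), [0, 1, 0, 1])
x, y = dataset[1]                             # first access opens the memmaps
clone = pickle.loads(pickle.dumps(dataset))   # roughly what a spawned worker receives
print(x.shape, y, clone.list_of_arrays)
```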

If that doesn’t work, you could consider loading the data into RAM completely rather than memory mapping (this depends on whether you have enough memory to fit the data). Or you could write your dataset so that it loads each sample individually from disk, but that would require you to split the .dat files into multiple files.
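For the per-sample option, a one-off conversion could look roughly like this. It is a minimal sketch assuming the first axis of each .dat file indexes the samples; `split_memmap_into_samples`, `load_sample`, and the `sample_000000.npy` naming scheme are made up for illustration.

```python
import os
import tempfile

import numpy as np


def split_memmap_into_samples(dat_path, shape, out_dir, dtype="float32"):
    """Write each sample (first axis) of a memmapped .dat file to its own .npy file."""
    os.makedirs(out_dir, exist_ok=True)
    source = np.memmap(dat_path, dtype=dtype, mode="r", shape=shape)
    paths = []
    for i in range(shape[0]):
        path = os.path.join(out_dir, f"sample_{i:06d}.npy")
        np.save(path, np.asarray(source[i]))  # only one sample is read at a time
        paths.append(path)
    del source  # release the mapping
    return paths


def load_sample(out_dir, index):
    """Load a single sample; this is what __getitem__ would call."""
    return np.load(os.path.join(out_dir, f"sample_{index:06d}.npy"))


# Small demonstration with a temporary 4 x 3 file.
work_dir = tempfile.mkdtemp()
demo = np.arange(12, dtype="float32").reshape(4, 3)
demo.tofile(os.path.join(work_dir, "demo.dat"))
sample_dir = os.path.join(work_dir, "samples")
split_memmap_into_samples(os.path.join(work_dir, "demo.dat"), (4, 3), sample_dir)
print(load_sample(sample_dir, 2))
```

With this layout the dataset only needs to store a directory path and the labels, both of which are cheap to pickle.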

It turned out that moving the creation of list_of_arrays from __init__ into __getitem__ solved the problem. It really seems like having it in __init__ caused the problems. I still suspect that it tried to load the memmaps into memory, which must not happen given my RAM limitations.
Thank you for the help.
