Support for TorchData - DataLoader2 Multiprocessing Issue

Hi all,

I’ve been exploring TorchData and trying to get DataPipe and DataLoader2 to work in tandem with Lightning. However, I’ve run into an issue with the multiprocessing and distributed data-loading side of DataLoader2. I was wondering if anyone has experience getting Lightning and TorchData working together, and where Lightning currently stands on TorchData support.

I have seen two issues on the Lightning GitHub repo about TorchData support ([Datapipe handling with more than 1 GPU gives torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with exit code 1 · Issue #13039 · Lightning-AI/lightning · GitHub] and Support setup for torch DataPipes · Issue #16603 · Lightning-AI/lightning · GitHub), but I haven’t seen anything come out of them yet.

I have put together a small minimal working example, but found some problems when using the multiprocessing reading service and sharding of the DataPipe. One of the problems is that the Trainer never calls DataLoader2’s shutdown() method, which leaves orphaned worker processes that I have to kill manually from the terminal. The error looks like this:

[rank: 0] Received SIGTERM: 15
[rank: 0] Received SIGTERM: 15
[rank: 0] Received SIGTERM: 15

My workaround is to hold the DataLoader2 instances as attributes on the DataModule and call their shutdown() method after training. This works for the minimal working example, but when I try to apply it to my other project I get another error. I have also tried trainer.train_dataloader.shutdown(), but that just gives 'NoneType' object has no attribute 'shutdown'.
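
Expressed inside the DataModule, the idea is roughly the following sketch (using Lightning’s teardown hook and the attribute names from the example below; in the actual example I just call shutdown() manually after fit):

    def teardown(self, stage: Optional[str] = None):
        # Explicitly shut down each DataLoader2 so its worker processes exit,
        # since the Trainer does not appear to call shutdown() itself.
        for dl in (self.train_dataloader2, self.val_dataloader2, self.test_dataloader2):
            if dl is not None:
                dl.shutdown()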

Here is my minimal working example:

from typing import Optional

import torch
from torch import nn
import torch.optim as optim
from lightning import LightningDataModule, LightningModule, Trainer
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService

def test_func(x):
    data, label = x
    data = data*2 
    return (data, label)

class TestDataModule(LightningDataModule):
    def __init__(self, num_workers):
        super().__init__()
        self.num_workers = num_workers
        self.data_size = 2000
        self.feature_size = 100

        self.train_datapipe: Optional[IterDataPipe] = None
        self.val_datapipe: Optional[IterDataPipe] = None
        self.test_datapipe: Optional[IterDataPipe] = None

        self.rs: Optional[MultiProcessingReadingService] = None

        self.train_dataloader2: Optional[DataLoader2] = None
        self.val_dataloader2: Optional[DataLoader2] = None
        self.test_dataloader2: Optional[DataLoader2] = None

    def prepare_data(self):
        pass

    def setup(self, stage: Optional[str] = None):
        if not self.train_datapipe or not self.val_datapipe or not self.test_datapipe:
            data = torch.randn(self.data_size, self.feature_size)
            labels = torch.randint(0, 2, (self.data_size,))
            # Materialise the samples up front so the datapipes can safely be
            # re-iterated every epoch and serialised to worker processes.
            samples = list(zip(data, labels))

            self.train_datapipe = IterableWrapper(samples) \
                .shuffle() \
                .sharding_filter() \
                .map(test_func)
            self.test_datapipe = IterableWrapper(samples) \
                .shuffle() \
                .sharding_filter() \
                .map(test_func)
            self.val_datapipe = IterableWrapper(samples) \
                .shuffle() \
                .sharding_filter() \
                .map(test_func)
        if not self.rs:
            self.rs = MultiProcessingReadingService(num_workers=self.num_workers)

        if not self.train_dataloader2 or not self.val_dataloader2 or not self.test_dataloader2:
            self.val_dataloader2 = DataLoader2(
                self.val_datapipe,
                reading_service=self.rs,
            )
            self.train_dataloader2 = DataLoader2(
                self.train_datapipe,
                reading_service=self.rs,
            )
            self.test_dataloader2 = DataLoader2(
                self.test_datapipe,
                reading_service=self.rs,
            )

    def train_dataloader(self):
        return self.train_dataloader2

    def val_dataloader(self):
        return self.val_dataloader2

    def test_dataloader(self):
        return self.test_dataloader2

# Define the model
class SimpleNN(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(100, 1)

    def forward(self, x):
        return self.layer(x).squeeze()

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.BCEWithLogitsLoss()(y_hat, y.float())
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = nn.BCEWithLogitsLoss()(y_hat, y.float())
        acc = ((y_hat > 0).float() == y.float()).float().mean()
        metrics = {"val_loss": loss, "val_acc": acc}
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=0.001)

# Training the model
if __name__ == "__main__":  # guard the entry point so it is safe with multiprocessing workers
    model = SimpleNN()
    dm = TestDataModule(num_workers=2)
    trainer = Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
    trainer.fit(model, dm)
    # Workaround: explicitly shut down the DataLoader2 instances so their
    # worker processes exit instead of lingering after training.
    dm.val_dataloader2.shutdown()
    dm.test_dataloader2.shutdown()
    dm.train_dataloader2.shutdown()
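
As a side note, I am not sure whether passing the same MultiProcessingReadingService instance to all three DataLoader2 objects is even supported, so a variant of setup() I would also try is giving each loader its own reading service (just a sketch, same attribute names as above):

            self.train_dataloader2 = DataLoader2(
                self.train_datapipe,
                reading_service=MultiProcessingReadingService(num_workers=self.num_workers),
            )
            self.val_dataloader2 = DataLoader2(
                self.val_datapipe,
                reading_service=MultiProcessingReadingService(num_workers=self.num_workers),
            )
            self.test_dataloader2 = DataLoader2(
                self.test_datapipe,
                reading_service=MultiProcessingReadingService(num_workers=self.num_workers),
            )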

I’m curious whether anyone else has tried using Lightning and TorchData with a multiprocessing or distributed reading service, and I’d be interested in discussing experiences and potential solutions for getting them to work smoothly together.

Best regards,

Moust Holmes