Hi all,
I’ve been exploring TorchData and trying to get DataPipes and DataLoader2 to work in tandem with Lightning. However, I’ve encountered an issue with the multiprocessing and distributed data-loading aspects of DataLoader2. I was wondering if anyone has experience getting Lightning and TorchData working together, and knows where Lightning stands on support for TorchData.
I have seen two issues on the Lightning GitHub repo about support for TorchData: “Datapipe handling with more than 1 GPU gives torch.multiprocessing.spawn.ProcessExitedException: process 1 terminated with exit code 1” (Issue #13039, Lightning-AI/lightning) and “Support setup for torch DataPipes” (Issue #16603, Lightning-AI/lightning), but I haven’t seen anything come out of them yet.
I have played around with a small minimal working example but found some problems when using the MultiProcessingReadingService together with sharding of the datapipe. One of the problems is that the trainer doesn’t call DataLoader2’s shutdown function, which leaves orphaned worker processes that I have to shut down manually from the terminal. The error looks like this:
[rank: 0] Received SIGTERM: 15
[rank: 0] Received SIGTERM: 15
[rank: 0] Received SIGTERM: 15
I have made a workaround by making the DataLoader2 a property of the datamodule and calling its shutdown function through the datamodule after training. This works for the minimal working example, but when I try to apply it to my other project I get another error. I have tried using trainer.train_dataloader.shutdown()
but I just get “'NoneType' object has no attribute 'shutdown'”.
Here is my minimal working example:
from typing import Any, Dict, Optional
import torch
from torch import nn
import torch.optim as optim
from lightning import LightningDataModule, LightningModule, Trainer
from torchdata.datapipes.iter import IterableWrapper, IterDataPipe, ShardingFilter, Shuffler, Mapper, RandomSplitter
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService
def test_func(x):
    """Double the features of one ``(data, label)`` sample.

    Args:
        x: A two-item sequence ``(data, label)``; ``data`` must support
            multiplication by an integer.

    Returns:
        A tuple ``(data * 2, label)``; the label is passed through unchanged.
    """
    data, label = x
    return (data * 2, label)


# The name matches pytest's default ``test_*`` pattern; opt out of collection
# so this pipeline helper is never mistaken for a test case.
test_func.__test__ = False
class TestDataModule(LightningDataModule):
    """Data module serving synthetic binary-classification data via DataLoader2.

    Random features and labels are generated in ``setup`` and wrapped in three
    identical datapipes (train / val / test), each read through a
    ``MultiProcessingReadingService``.

    Args:
        num_workers: Number of worker processes passed to the reading service.
    """

    def __init__(self, num_workers: int):
        super().__init__()
        self.num_workers = num_workers
        self.data_size = 2000
        self.feature_size = 100
        # Everything below is created lazily in ``setup``.
        self.train_datapipe: Optional[IterDataPipe] = None
        self.val_datapipe: Optional[IterDataPipe] = None
        self.test_datapipe: Optional[IterDataPipe] = None
        self.rs: Optional[MultiProcessingReadingService] = None
        self.train_dataloader2: Optional[DataLoader2] = None
        self.val_dataloader2: Optional[DataLoader2] = None
        self.test_dataloader2: Optional[DataLoader2] = None

    def prepare_data(self):
        """No-op: the data is generated in memory in ``setup``."""
        pass

    def _build_datapipe(self, samples) -> IterDataPipe:
        """Wrap ``samples`` in the shuffle -> shard -> map(test_func) pipeline."""
        return (
            IterableWrapper(samples)
            .shuffle()
            .sharding_filter()
            .map(test_func)
        )

    def setup(self, stage: Optional[str] = None):
        """Create the datapipes, reading service and loaders if not yet built.

        Every creation step is guarded so that calling ``setup`` more than
        once (e.g. once per stage) reuses the existing objects.

        Args:
            stage: Stage name supplied by the caller; not used here.
        """
        # Compare against None explicitly: truth-testing a datapipe falls back
        # to ``__len__``, which can raise TypeError for pipes of unknown length.
        if (
            self.train_datapipe is None
            or self.val_datapipe is None
            or self.test_datapipe is None
        ):
            data = torch.randn(self.data_size, self.feature_size)
            labels = torch.randint(0, 2, (self.data_size,))
            # Materialize the pairs: a bare zip() is a one-shot iterator with
            # no length, whereas a list can be re-iterated on every epoch.
            samples = list(zip(data, labels))
            self.train_datapipe = self._build_datapipe(samples)
            self.test_datapipe = self._build_datapipe(samples)
            self.val_datapipe = self._build_datapipe(samples)

        if self.rs is None:
            self.rs = MultiProcessingReadingService(num_workers=self.num_workers)

        if (
            self.train_dataloader2 is None
            or self.val_dataloader2 is None
            or self.test_dataloader2 is None
        ):
            self.val_dataloader2 = DataLoader2(
                self.val_datapipe,
                reading_service=self.rs,
            )
            self.train_dataloader2 = DataLoader2(
                self.train_datapipe,
                reading_service=self.rs,
            )
            self.test_dataloader2 = DataLoader2(
                self.test_datapipe,
                reading_service=self.rs,
            )

    def train_dataloader(self):
        """Return the training ``DataLoader2`` (``None`` before ``setup``)."""
        return self.train_dataloader2

    def val_dataloader(self):
        """Return the validation ``DataLoader2`` (``None`` before ``setup``)."""
        return self.val_dataloader2

    def test_dataloader(self):
        """Return the test ``DataLoader2`` (``None`` before ``setup``)."""
        return self.test_dataloader2
# Define the model
class SimpleNN(LightningModule):
    """Single linear layer trained as a binary classifier on logits.

    Args:
        in_features: Size of the input feature vector (defaults to 100, the
            value previously hard-coded).
    """

    def __init__(self, in_features: int = 100):
        super().__init__()
        self.layer = nn.Linear(in_features, 1)
        # Built once instead of on every step; it holds no parameters.
        self.loss_fn = nn.BCEWithLogitsLoss()

    def forward(self, x):
        """Return one logit per sample, with the trailing size-1 dim removed."""
        # squeeze(-1) rather than squeeze(): a bare squeeze would also drop a
        # batch dimension of size 1 and then mismatch the target's shape.
        return self.layer(x).squeeze(-1)

    def training_step(self, batch, batch_idx):
        """Compute the BCE-with-logits loss for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        return self.loss_fn(y_hat, y.float())

    def validation_step(self, batch, batch_idx):
        """Log and return ``val_loss`` and ``val_acc`` for one ``(x, y)`` batch."""
        x, y = batch
        y_hat = self(x)
        target = y.float()
        loss = self.loss_fn(y_hat, target)
        # A logit above 0 is counted as predicting the positive class.
        acc = ((y_hat > 0).float() == target).float().mean()
        metrics = {"val_loss": loss, "val_acc": acc}
        self.log_dict(metrics)
        return metrics

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters with lr=0.001."""
        return optim.Adam(self.parameters(), lr=0.001)
# Training the model
# Guarded entry point: worker processes started by multiprocessing may
# re-import this module, and must not start another training run.
if __name__ == "__main__":
    model = SimpleNN()
    dm = TestDataModule(num_workers=2)
    trainer = Trainer(accelerator="gpu", devices=1, min_epochs=1, max_epochs=3, precision=16)
    try:
        trainer.fit(model, dm)
    finally:
        # Always release the reading-service workers, even if fit() raised;
        # a loader is still None when setup() never ran.
        for loader in (dm.val_dataloader2, dm.test_dataloader2, dm.train_dataloader2):
            if loader is not None:
                loader.shutdown()
I’m curious to know if anyone else has tried using Lightning and TorchData with multi-process/distributed reading service and would be interested in discussing their experiences and potential solutions to get it working smoothly.
Best regards,
Moust Holmes