When I tried to use LightningDataModule, I first defined a DataModule like this:
import logging
import os

import lightning.pytorch as pl
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

logger = logging.getLogger(__name__)

class MultiSceneDataModuleNew(pl.LightningDataModule):
    def __init__(self, args, config):
        super().__init__()
        self.data_root = config.DATASET.DATA_ROOT  # Root path to EPIC-FIELDS dataset
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.pin_memory = getattr(args, "pin_memory", True)

    def setup(self, stage: str = None):
        # Two independent ifs so that stage=None sets up both fit and test splits
        if stage in ("fit", None):
            self.train_dataset = self._setup_dataset(mode="train")
            self.val_dataset = self._setup_dataset(mode="val")
            logger.info("Train & Val Dataset loaded!")
        if stage in ("test", None):
            self.test_dataset = self._setup_dataset(mode="test")
            logger.info("Test Dataset loaded!")

    def _setup_dataset(self, mode):
        datasets = []
        for scene_folder in tqdm(os.listdir(self.data_root), desc=f"Loading {mode} datasets"):
            scene_path = os.path.join(self.data_root, scene_folder)
            if os.path.isdir(scene_path):
                json_path = os.path.join(self.data_root, f"{scene_folder}.json")
                if os.path.exists(json_path):
                    datasets.append(EpicKitchensDataset(scene_path, json_path, mode))
        return ConcatDataset(datasets)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)
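For context, _setup_dataset assumes the EPIC-FIELDS root contains one folder per scene plus a sibling JSON file with the same name. The layout below is only illustrative (the scene names are made up):

    data_root/
        P01_01/          <- scene folder containing the frames
        P01_01.json      <- per-scene annotations (camera intrinsics + per-image poses)
        P02_03/
        P02_03.json
        ...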
I also defined a Dataset class that implements the __len__() and __getitem__() methods, like this:
import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import resize
from transformers import pipeline

class EpicKitchensDataset(Dataset):
    def __init__(self, scene_folder, json_file, mode="train"):
        self.scene_folder = scene_folder
        with open(json_file, "r") as f:
            self.data = json.load(f)
        self.cam_matrices, self.world_to_cam_poses = self.parse_camera_data(self.data["camera"])
        self.transform = transforms.ToTensor()
        self.depth_pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")

    def parse_camera_data(self, camera_data):
        # Extract and construct intrinsic and extrinsic matrices
        fx, fy, cx, cy = camera_data["params"][:4]
        intrinsic_matrix = torch.tensor([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float32)
        world_to_cam_poses = []
        for pose in self.data["images"].values():
            R = qvec2rotmat(pose[:4])  # quaternion -> 3x3 rotation matrix
            t = np.array(pose[4:]).reshape(3, 1)
            extrinsic_matrix = np.hstack((R, t))
            extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
            world_to_cam_poses.append(torch.tensor(extrinsic_matrix, dtype=torch.float32))
        return intrinsic_matrix, torch.stack(world_to_cam_poses)

    def __len__(self):
        return len(self.data["images"])

    def __getitem__(self, idx):
        frame_name = list(self.data["images"].keys())[idx]
        image_path = os.path.join(self.scene_folder, frame_name)
        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        img = Image.open(image_path)
        img = self.transform(img)
        if img.shape[1:] != (256, 456):
            print(f"Resizing image {image_path} from {tuple(img.shape[1:])} to (256, 456)")
            img = resize(img, [256, 456])
        depth = self.depth_pipe(image_path)["depth"]
        depth_tensor = self.transform(depth)
        sample = {
            "img": img,
            "gt_depth": depth_tensor,
            "intrinsics": self.cam_matrices,
            "world_to_cam_pose": self.world_to_cam_poses[idx],
        }
        return sample
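(qvec2rotmat is not shown above; I use the standard COLMAP helper, which assumes a (w, x, y, z) quaternion, along these lines:)

    import numpy as np

    def qvec2rotmat(qvec):
        # Standard COLMAP conversion: quaternion (w, x, y, z) -> 3x3 rotation matrix
        w, x, y, z = qvec
        return np.array([
            [1 - 2 * y * y - 2 * z * z, 2 * x * y - 2 * w * z,     2 * x * z + 2 * w * y],
            [2 * x * y + 2 * w * z,     1 - 2 * x * x - 2 * z * z, 2 * y * z - 2 * w * x],
            [2 * x * z - 2 * w * y,     2 * y * z + 2 * w * x,     1 - 2 * x * x - 2 * y * y],
        ])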
The datamodule works fine when I call it like this:
data_module = MultiSceneDataModuleNew(args, config)
data_module.setup("fit")
print(data_module.train_dataset[0])
print(next(iter(data_module.train_dataloader())))
trainer.fit(model, datamodule=data_module)
However, I received the error below when I ran trainer.fit(model, datamodule=data_module):
Traceback (most recent call last):
File "/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 402, in _check_dataloader_iterable
iter(dataloader) # type: ignore[call-overload]
TypeError: 'MultiSceneDataModuleNew' object is not iterable
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 407, in _check_dataloader_iterable
raise TypeError(
TypeError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found <src.lightning.data_new.MultiSceneDataModuleNew object at 0x7fe006450640>.
/python3.9/multiprocessing/resource_tracker.py:96: UserWarning: resource_tracker: process died unexpectedly, relaunching. Some resources might leak.
warnings.warn('resource_tracker: process died unexpectedly, '
I suspect the problem is somewhere in the data module, but I cannot locate it. Could anyone give me some insight? Any help would be appreciated. Thanks!
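In case it helps narrow this down, here is a minimal check I plan to run (a sketch; it assumes the Trainer in my script comes from the lightning.pytorch namespace that the traceback points at):

    import lightning.pytorch as lpl

    data_module = MultiSceneDataModuleNew(args, config)
    # If this prints False, the DataModule base class and the Trainer come from
    # different packages (e.g. pytorch_lightning vs lightning.pytorch), and
    # Lightning may then fail to recognize the object as a datamodule.
    print(isinstance(data_module, lpl.LightningDataModule))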