TypeError: 'MultiSceneDataModuleNew' object is not iterable when calling trainer.fit

I am trying to use a LightningDataModule. I first defined the DataModule like this:

import os

import lightning.pytorch as pl
from torch.utils.data import ConcatDataset, DataLoader
from tqdm import tqdm

# logger comes from my project-wide logging setup

class MultiSceneDataModuleNew(pl.LightningDataModule):
    def __init__(self, args, config):
        super().__init__()
        self.data_root = config.DATASET.DATA_ROOT  # Root path to EPIC-FIELDS dataset
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.pin_memory = getattr(args, "pin_memory", True)

    def setup(self, stage: str = None):
        # Two independent checks so that stage=None builds both the fit and test splits
        if stage == "fit" or stage is None:
            self.train_dataset = self._setup_dataset(mode="train")
            self.val_dataset = self._setup_dataset(mode="val")
            logger.info("Train & Val Dataset loaded!")

        if stage == "test" or stage is None:
            self.test_dataset = self._setup_dataset(mode="test")
            logger.info("Test Dataset loaded!")

    def _setup_dataset(self, mode):
        datasets = []
        for scene_folder in tqdm(os.listdir(self.data_root), desc=f"Loading {mode} datasets"):
            scene_path = os.path.join(self.data_root, scene_folder)
            if os.path.isdir(scene_path):
                json_path = os.path.join(self.data_root, f"{scene_folder}.json")
                if os.path.exists(json_path):
                    datasets.append(EpicKitchensDataset(scene_path, json_path, mode))
        return ConcatDataset(datasets)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False,
                          num_workers=self.num_workers, pin_memory=self.pin_memory)
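
For context, _setup_dataset expects the EPIC-FIELDS data root to contain one folder of frames per scene, with a JSON file of the same name next to it. The scene names below are only illustrative:

    <DATA_ROOT>/
        P01_01/          # frames for one scene
        P01_01.json      # camera + image metadata for that scene
        P22_07/
        P22_07.json
        ...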

I also defined a Dataset class that implements the __len__() and __getitem__() methods:

import json
import os

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms.functional import resize
from transformers import pipeline

# qvec2rotmat comes from my project utilities (a COLMAP-style sketch of it is shown after the class)

class EpicKitchensDataset(Dataset):
    def __init__(self, scene_folder, json_file, mode="train"):
        self.scene_folder = scene_folder
        with open(json_file, 'r') as f:
            self.data = json.load(f)

        self.cam_matrices, self.world_to_cam_poses = self.parse_camera_data(self.data['camera'])
        self.transform = transforms.ToTensor()
        self.depth_pipe = pipeline(task="depth-estimation", model="LiheYoung/depth-anything-small-hf")

    def parse_camera_data(self, camera_data):
        # Extract and construct intrinsic and extrinsic matrices
        fx, fy, cx, cy = camera_data['params'][:4]
        intrinsic_matrix = torch.tensor([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=torch.float32)

        world_to_cam_poses = []
        for pose in self.data['images'].values():
            R = qvec2rotmat(pose[:4])
            t = np.array(pose[4:]).reshape(3, 1)
            extrinsic_matrix = np.hstack((R, t))
            extrinsic_matrix = np.vstack((extrinsic_matrix, [0, 0, 0, 1]))
            world_to_cam_poses.append(torch.tensor(extrinsic_matrix, dtype=torch.float32))

        return intrinsic_matrix, torch.stack(world_to_cam_poses)

    def __len__(self):
        return len(self.data['images'])

    def __getitem__(self, idx):
        frame_name = list(self.data['images'].keys())[idx]
        image_path = os.path.join(self.scene_folder, frame_name)

        if not os.path.exists(image_path):
            raise FileNotFoundError(f"Image not found: {image_path}")
        img = Image.open(image_path)

        # Convert to tensor and resize to the expected resolution
        img = self.transform(img)
        if img.shape[1:] != (256, 456):
            print(f"Resizing image {image_path} from {img.shape[1:]} to (256, 456)")
            img = resize(img, [256, 456])

        # The depth-estimation pipeline returns a dict; its "depth" entry is a PIL image,
        # which ToTensor can handle directly
        depth = self.depth_pipe(image_path)["depth"]
        depth_tensor = self.transform(depth)

        sample = {
            "img": img,
            "gt_depth": depth_tensor,
            "intrinsics": self.cam_matrices,
            "world_to_cam_pose": self.world_to_cam_poses[idx]
        }
        
        return sample
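
In case it is relevant, qvec2rotmat is the usual COLMAP-style helper that turns a quaternion stored as (w, x, y, z) into a 3x3 rotation matrix. This is only a sketch of what it does, not my exact implementation:

    import numpy as np

    def qvec2rotmat(qvec):
        # unit quaternion (w, x, y, z) -> 3x3 rotation matrix
        w, x, y, z = qvec
        return np.array([
            [1 - 2 * (y * y + z * z), 2 * (x * y - w * z),     2 * (x * z + w * y)],
            [2 * (x * y + w * z),     1 - 2 * (x * x + z * z), 2 * (y * z - w * x)],
            [2 * (x * z - w * y),     2 * (y * z + w * x),     1 - 2 * (x * x + y * y)],
        ])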

The datamodule works fine when I call it manually like this:

    data_module = MultiSceneDataModuleNew(args, config)
    data_module.setup("fit")
    print(data_module.train_dataset[0])
    print(next(iter(data_module.train_dataloader())))
    trainer.fit(model, datamodule=data_module)
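
As an extra sanity check, I could also verify the hook return types directly. This snippet is just a sketch (it is not part of my training script), but it should show whether the hooks return plain DataLoader objects:

    from torch.utils.data import DataLoader

    loader = data_module.train_dataloader()
    print(type(loader), isinstance(loader, DataLoader))   # expecting a DataLoader and True
    batch = next(iter(loader))
    print({k: tuple(v.shape) for k, v in batch.items()})  # img, gt_depth, intrinsics, world_to_cam_pose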

However, I received the error below when I ran trainer.fit(model, datamodule=data_module):

Traceback (most recent call last):
  File "/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 402, in _check_dataloader_iterable
    iter(dataloader)  # type: ignore[call-overload]
TypeError: 'MultiSceneDataModuleNew' object is not iterable

During handling of the above exception, another exception occurred:
Traceback (most recent call last):
 File "/python3.9/site-packages/lightning/pytorch/trainer/connectors/data_connector.py", line 407, in _check_dataloader_iterable
    raise TypeError(
TypeError: An invalid dataloader was passed to `Trainer.fit(train_dataloaders=...)`. Found <src.lightning.data_new.MultiSceneDataModuleNew object at 0x7fe006450640>.
/python3.9/multiprocessing/resource_tracker.py:96: UserWarning: resource_tracker: process died unexpectedly, relaunching.  Some resources might leak.
  warnings.warn('resource_tracker: process died unexpectedly, '

I think the error is caused by something in the data module, but I cannot locate it. Could anyone give me some insight? Any help would be appreciated. Thanks! :slight_smile:

:face_with_head_bandage: :yum: Can anyone help?