LightningModule.train_dataloader()

How do the hooks for the LightningModule interact with the hooks for the LightningDataModule?
Does one override the other? Previously, I was able to call LightningDataModule.train_dataloader() from within LightningModule.train_dataloader(), but it seems that the latter is no longer called at all when using trainer.fit(model, datamodule=datamodule).

My use case is that I'd like to modify the dataloader with a function that needs access to the optimizer, so I'd like to do it from the model:

from typing import Any

from lightning import LightningModule


class Classifier(LightningModule):
    def __init__(
        self,
        *args,
        **kwargs,
    ):
        super().__init__()
        # model initialized here; self.optimizer is assumed to be set elsewhere

    def train_dataloader(self) -> Any:
        dl = self.trainer.datamodule.train_dataloader()
        if not hasattr(self.trainer.datamodule, "batch_size_physical"):
            return dl  # just use the LightningDataModule's dataloader as is
        # otherwise wrap it with the helper (wrap_data_loader, defined elsewhere)
        return wrap_data_loader(
            data_loader=dl,
            max_batch_size=self.trainer.datamodule.batch_size_physical,
            optimizer=self.optimizer,
        )

Hey

If you pass a LightningDataModule with train_dataloader() etc. implemented, those methods will be called instead of the ones on the LightningModule. This is useful if you want to decouple the data-related definitions from your model. In your case you can't fully decouple them, because you said the dataloader depends on the optimizer. So perhaps it's best if you skip the DataModule and implement the hooks on the LightningModule instead; a minimal sketch follows below.
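
For illustration, here is a rough sketch of that approach, with all data hooks living on the LightningModule and no DataModule at all. The class name, dataset, batch size, and model layer are placeholders, and the assumption that self.trainer.optimizers is already populated when train_dataloader() runs should be verified for your Lightning version:

from typing import Any

import torch
from torch.utils.data import DataLoader, TensorDataset
from lightning import LightningModule


class SelfContainedClassifier(LightningModule):
    """Sketch: all data hooks live on the LightningModule, so no DataModule is needed."""

    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size
        self.layer = torch.nn.Linear(16, 2)

    def training_step(self, batch, batch_idx):
        x, y = batch
        return torch.nn.functional.cross_entropy(self.layer(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())

    def train_dataloader(self) -> Any:
        # Placeholder data just to make the sketch runnable; swap in your real dataset.
        ds = TensorDataset(torch.randn(128, 16), torch.randint(0, 2, (128,)))
        dl = DataLoader(ds, batch_size=self.batch_size)
        # Assumption: by the time the Trainer requests this dataloader,
        # configure_optimizers() has already run, so self.trainer.optimizers
        # can be inspected here (e.g. to pick a physical batch size).
        return dl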

Thanks for the response. Unfortunately, I've already tried that, and it leads to an error because the Trainer expects this to be implemented on the LightningDataModule:

lightning.fabric.utilities.exceptions.MisconfigurationException: `train_dataloader` must be implemented to be used with the Lightning Trainer

So you either implement train_dataloader() on the LightningDataModule and pass the datamodule to the Trainer, or you implement train_dataloader() on the LightningModule and pass the LightningModule itself to the Trainer, for example:
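
Here is a rough sketch of the two options; Classifier, MyDataModule, and SelfContainedClassifier stand in for your own classes:

from lightning import Trainer

trainer = Trainer(max_epochs=1)

# Option A: train_dataloader() lives on the LightningDataModule
trainer.fit(Classifier(), datamodule=MyDataModule())

# Option B: train_dataloader() lives on the LightningModule itself
trainer.fit(SelfContainedClassifier())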

Thanks. So there is no way to use a datamodule and override only a specific dataloader in the model? I'd need to implement the whole datamodule in the model?