Access datamodule from within LightningModule

I have a solution for this, but I’m wondering if there is a better/more idiomatic way to accomplish the same thing.

Brief background:
I have a LightningModule and would like to configure a torch.optim.lr_scheduler.OneCycleLR learning rate scheduler. I instantiate my optimizer and scheduler in the configure_optimizers method as usual. To create the OneCycleLR instance, I need to know how large my training data is so I can pass the steps_per_epoch and epochs kwargs to the constructor.
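For reference, the number OneCycleLR needs can be derived from the dataset size and batch size. This is a minimal sketch that assumes a single-process run without gradient accumulation. The helper names `steps_per_epoch` and `one_cycle_total_steps` are hypothetical and used only for illustration:

```python
import math


def steps_per_epoch(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches (and hence optimizer steps) per epoch for a single process."""
    if drop_last:
        # The final incomplete batch is discarded.
        return num_samples // batch_size
    # The final incomplete batch is kept, so round up.
    return math.ceil(num_samples / batch_size)


def one_cycle_total_steps(num_samples: int, batch_size: int, epochs: int,
                          drop_last: bool = False) -> int:
    """Total steps OneCycleLR schedules when given epochs and steps_per_epoch."""
    return epochs * steps_per_epoch(num_samples, batch_size, drop_last)


print(steps_per_epoch(1000, 32))                   # 32
print(steps_per_epoch(1000, 32, drop_last=True))   # 31
print(one_cycle_total_steps(1000, 32, epochs=10))  # 320
```

Dividing the dataset size by the batch size only gives the right count if the rounding matches the DataLoader's `drop_last` setting.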

I am using a DataModule and passing it to trainer.fit(), e.g.

trainer = Trainer(...)
trainer.fit(model, datamodule=myDataModule)

My case is quite simple: no distributed training, a DataModule with only a single training dataloader, etc.

My question is: what is the most straightforward way to get the size of the training data? So far, I have been able to get this with the following:

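def configure_optimizers(self):
    # Relies on trainer.fit() having been called with a DataModule
    train_dataloader = self.trainer.datamodule.train_dataloader()
    train_size = len(train_dataloader.dataset)  # number of training samples
    batch_size = train_dataloader.batch_size  # samples per batch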
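In a single-process run, `len()` of a DataLoader already reports the number of batches per epoch, with `batch_size` and `drop_last` taken into account. That removes the need to combine the dataset size and batch size by hand. The following is a minimal PyTorch-only sketch. The toy dataset, the `Linear` model, and the `max_lr` value are illustrative assumptions:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy dataset standing in for the DataModule's training set (illustrative sizes).
dataset = TensorDataset(torch.randn(1000, 8), torch.randn(1000, 1))
train_dataloader = DataLoader(dataset, batch_size=32, shuffle=True)

model = torch.nn.Linear(8, 1)
# OneCycleLR cycles momentum by default, so use an optimizer that has momentum.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

# len(DataLoader) is the number of batches per epoch: ceil(1000 / 32) = 32 here.
batches_per_epoch = len(train_dataloader)
epochs = 5

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.1,
    epochs=epochs,
    steps_per_epoch=batches_per_epoch,
)

# The schedule spans epochs * steps_per_epoch calls to scheduler.step().
print(scheduler.total_steps)  # 160
```

Inside configure_optimizers, a scheduler like this would normally be returned with `"interval": "step"` in the `lr_scheduler` dict. Lightning steps schedulers once per epoch by default, but OneCycleLR expects to be stepped after every batch.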

For this project, I will always be using a DataModule, so it should be safe to assume that self.trainer.datamodule will be present. Is there a more general way to access training set details, though? For instance, if one passes the train_dataloaders kwarg to trainer.fit() instead, I assume the snippet above would fail.
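One possibly more general option is the `Trainer.estimated_stepping_batches` property, assumed to be available in Lightning 1.6 and later (check the Trainer docs for your version). It estimates the total number of optimizer steps for the whole fit from the Trainer's configuration and its training dataloader, including gradient accumulation, and it does so regardless of how that dataloader was supplied. Passing it as OneCycleLR's `total_steps` avoids computing `epochs` and `steps_per_epoch` yourself.

This is a minimal sketch. The helper name `build_one_cycle` and the default `max_lr` are hypothetical. The helper only touches `module.parameters()` and `module.trainer.estimated_stepping_batches`, so it can be exercised without a real Trainer:

```python
import math

import torch
from torch.optim.lr_scheduler import OneCycleLR


def build_one_cycle(module: torch.nn.Module, max_lr: float = 1e-3) -> dict:
    """Return a configure_optimizers()-style dict with a per-step OneCycleLR.

    `module` is expected to be a LightningModule attached to a Trainer, so that
    `module.trainer.estimated_stepping_batches` exists (assumed Lightning >= 1.6).
    """
    total_steps = module.trainer.estimated_stepping_batches
    if math.isinf(total_steps):
        # An unbounded run (no max_epochs/max_steps limit) cannot use a fixed cycle.
        raise ValueError("OneCycleLR needs a finite number of training steps")

    optimizer = torch.optim.AdamW(module.parameters(), lr=max_lr)
    # OneCycleLR requires an int total_steps; the property may return a float.
    scheduler = OneCycleLR(optimizer, max_lr=max_lr, total_steps=int(total_steps))
    return {
        "optimizer": optimizer,
        # Step the scheduler after every optimizer step, not once per epoch.
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```

Inside the LightningModule, configure_optimizers would then just be `return build_one_cycle(self, max_lr=1e-3)`.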