I have a solution for this, but I’m wondering if there is a better/more idiomatic way to accomplish the same thing.
Brief background:
I have a `LightningModule` and would like to configure a `torch.optim.lr_scheduler.OneCycleLR` learning rate scheduler. I instantiate my optimizer and scheduler in the `configure_optimizers` method as usual. To create the `OneCycleLR` instance, I need to know how large my training data is so I can pass the `steps_per_epoch` and `epochs` kwargs to the constructor.
I am using a `DataModule` and passing it to the `Trainer.fit` method, e.g.

```python
trainer = Trainer(...)
trainer.fit(model=myPLModule, datamodule=myDataModule)
```

In this case, it's quite simple (no distributed training, my `DataModule` has only a single dataloader, etc.).
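For concreteness, the `DataModule` is roughly of this shape (the dataset, names, and batch size below are placeholders, not my actual code):

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class MyDataModule(pl.LightningDataModule):
    def __init__(self, batch_size: int = 32):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Stand-in dataset; the real module loads my actual training data here.
        self.train_set = TensorDataset(
            torch.randn(1000, 16), torch.randint(0, 2, (1000,))
        )

    def train_dataloader(self):
        # Single training dataloader, as mentioned above.
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)
```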
My question is: what is the most straightforward way to get the size of the training data? So far, I have been able to get this with the following:
```python
def configure_optimizers(self):
    train_dataloader = self.trainer.datamodule.train_dataloader()
    train_size = len(train_dataloader.dataset)
    batch_size = train_dataloader.batch_size
    ...
```
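In case it helps, the elided part just builds the optimizer and scheduler from those values, roughly like this (the optimizer choice and `max_lr` are placeholders, not my real hyperparameters):

```python
import math

import torch


def configure_optimizers(self):
    train_dataloader = self.trainer.datamodule.train_dataloader()
    train_size = len(train_dataloader.dataset)
    batch_size = train_dataloader.batch_size

    # Placeholder optimizer and max_lr; the point is only how steps_per_epoch
    # and epochs get passed to OneCycleLR.
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=1e-2,
        steps_per_epoch=math.ceil(train_size / batch_size),
        epochs=self.trainer.max_epochs,
    )
    # Step the scheduler every batch, as OneCycleLR expects.
    return {
        "optimizer": optimizer,
        "lr_scheduler": {"scheduler": scheduler, "interval": "step"},
    }
```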
For this project, I will always be using a `DataModule`, so it should be safe to assume that `self.trainer.datamodule` will be present, but is there a more general way to access training set details? For instance, if one passes the `train_dataloaders` kwarg to `Trainer.fit` instead, then I assume the snippet above would fail.
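That is, something like the following (hypothetical) call, where no datamodule is attached to the trainer and `self.trainer.datamodule` would presumably not be set:

```python
# Hypothetical: passing a plain DataLoader instead of a DataModule,
# in which case the snippet above has no datamodule to look up.
trainer = Trainer(...)
trainer.fit(model=myPLModule, train_dataloaders=myTrainDataloader)
```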
Thanks!