I have recently moved from vanilla PyTorch to Lightning, since I really like how it organizes the code, especially the DataModule. According to the docs, DataModules are for you if you ever asked the questions (emphasis mine):
- what splits did you use?
- what transforms did you use?
- what normalization did you use?
- how did you prepare/tokenize the data?
I am interested in the following scenario. I am given a dataset, and I want to perform the following steps:
1. Split the dataset into train, validation and test sets.
2. Train the model and save it (no problem here; I can load any checkpoint, as in the sketch after this list).
3. Come back later (maybe after two days) and test the model on the test set from step (1).
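(To make steps (2) and (3) concrete, here is roughly what I am doing; `LitModel`, `model`, `dm`, the paths, and the hyperparameters are placeholders of mine, not from the docs.)

```python
import lightning as L

# Step (2): train; the Trainer writes checkpoints automatically by default.
trainer = L.Trainer(max_epochs=10, default_root_dir="runs/")
trainer.fit(model, datamodule=dm)

# Step (3), maybe two days later: restore the model from a checkpoint...
model = LitModel.load_from_checkpoint("path/to/checkpoint.ckpt")

# ...and evaluate. The catch: dm must reproduce the *same* test split as in step (1).
trainer.test(model, datamodule=dm)
```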
One way to achieve this is to save the indices of the split from step (1), along the lines of the sketch below.
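(A rough sketch of the index-saving approach; it assumes the splits come from `random_split`, whose `Subset` results expose an `.indices` attribute, and the variable names and the `split_indices.pt` filename are placeholders of mine.)

```python
import torch
from torch.utils.data import Subset

# Right after the initial split: persist the indices of each Subset.
torch.save(
    {"train": train_ds.indices, "val": val_ds.indices, "test": test_ds.indices},
    "split_indices.pt",
)

# Days later: rebuild exactly the same subsets from the saved indices.
saved = torch.load("split_indices.pt")
test_ds = Subset(full_dataset, saved["test"])
```

However, I think a more elegant solution is the documented one: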
```python
import torch
from torch.utils.data import random_split
from torchvision.datasets import MNIST
import lightning as L


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(  ### This is what I want to maintain.
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
```
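(If I understand correctly, the fixed `manual_seed(42)` generator is what keeps the split stable: re-running `setup()` in a later session should reproduce exactly the same train/val partition without storing any indices.)

```python
dm = MNISTDataModule(data_dir="path/to/dir")
dm.setup(stage="test")  # same seed, so the same split as during training
```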
Basically, my problem boils down to saving the state of the DataModule. I have read the Save DataModule state section of the docs, but I can't understand what is saved under the hood and how I am supposed to load it back later if I want to perform inference.
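(For what it's worth, here is my current reading of that section, as a sketch; the `state_dict`/`load_state_dict` hooks are from the docs, but the key name and what I store under it are my guesses.)

```python
class MyDataModule(L.LightningDataModule):
    def state_dict(self):
        # Apparently whatever is returned here gets written into the
        # Trainer checkpoint alongside the model weights.
        return {"train_indices": self.mnist_train.indices}

    def load_state_dict(self, state_dict):
        # ...and this is called when the checkpoint is loaded back.
        self.train_indices = state_dict["train_indices"]
```

Is this how one is supposed to round-trip the split, and does loading a checkpoint with the `Trainer` automatically call `load_state_dict` on the attached datamodule?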
Please note that I am not interested in the MNIST example specifically.