What is the proper way to train a model, save it and then test it, avoiding information leakage and guaranteeing reproducibility?

I have recently moved from vanilla PyTorch to Lightning since I very much like how it organizes the code, especially the DataModule. According to the docs, DataModules are for you if you have ever asked the questions (emphasis mine):

  • what splits did you use?
  • what transforms did you use?
  • what normalization did you use?
  • how did you prepare/tokenize the data?

I am interested in the following scenario. I am given a dataset and I want to perform the following steps:

  1. Split the dataset into train, validation and test.
  2. Train the model and save it (no problem here, I can load any checkpoint).
  3. Come back later (maybe after two days) and test the model on the test set from step (1).

One way to achieve that is to save the indices of the split from step (1); a sketch of that approach appears after the code below. However, I think a more elegant solution is the documented one:

import torch
import lightning as L
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class MNISTDataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        self.mnist_predict = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        self.mnist_train, self.mnist_val = random_split(          ### This is what I want to maintain.
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
        )
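For reference, the index-saving alternative I mentioned could look roughly like this (a rough sketch; the file name and the use of Subset are my own illustration, not something from the docs):

import torch
from torch.utils.data import Subset, random_split
from torchvision.datasets import MNIST

mnist_full = MNIST("path/to/dir", train=True)
mnist_train, mnist_val = random_split(
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)

# Persist the index lists produced by random_split ...
torch.save({"train": mnist_train.indices, "val": mnist_val.indices}, "split_indices.pt")

# ... and rebuild exactly the same subsets later (e.g. two days later).
indices = torch.load("split_indices.pt")
mnist_train_again = Subset(mnist_full, indices["train"])
mnist_val_again = Subset(mnist_full, indices["val"])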

Basically, my problem boils down to saving the state of the DataModule. I have read the Save DataModule state section of the docs, but I can't understand what is saved under the hood and how I am supposed to load it back if I want to perform inference.

Please note that I am not interested in the MNIST example.

If you want to maintain the same split created by

self.mnist_train, self.mnist_val = random_split(          ### This is what I want to maintain.
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)

then hardcoding the seed there, as is already done, is enough: it will always create the exact same split.
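If you want to convince yourself, here is a quick self-contained check on a toy dataset (purely illustrative):

import torch
from torch.utils.data import TensorDataset, random_split

dataset = TensorDataset(torch.arange(10))
split_a = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))
split_b = random_split(dataset, [8, 2], generator=torch.Generator().manual_seed(42))

# Same seed, same lengths -> identical index lists, hence identical subsets.
assert split_a[0].indices == split_b[0].indices
assert split_a[1].indices == split_b[1].indices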

In general, if you want to save the state of the DataModule, implement the two methods state_dict() and load_state_dict():

https://lightning.ai/docs/pytorch/stable/data/datamodule.html#state-dict
https://lightning.ai/docs/pytorch/stable/data/datamodule.html#load-state-dict

For example:

def state_dict(self):
    return {"seed": self.seed}

def load_state_dict(self, state_dict):
    self.seed = state_dict["seed"]

If you use the Trainer, these methods are called automatically to save the DataModule state into the checkpoint and to restore it when loading from that checkpoint.

So my DataModule must look like this?

class DataModule(L.LightningDataModule):
    def __init__(self, seed):
        super().__init__()
        self.seed = seed

    # Define the other methods.
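Put together, I imagine something along these lines (only a sketch, reusing the MNIST example from above as a stand-in; transforms and the dataloader methods are omitted):

import torch
import lightning as L
from torch.utils.data import random_split
from torchvision.datasets import MNIST


class DataModule(L.LightningDataModule):
    def __init__(self, data_dir: str = "path/to/dir", batch_size: int = 32, seed: int = 42):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.seed = seed

    def setup(self, stage: str):
        self.mnist_test = MNIST(self.data_dir, train=False)
        mnist_full = MNIST(self.data_dir, train=True)
        # The split depends only on self.seed, so restoring the seed restores the split.
        self.mnist_train, self.mnist_val = random_split(
            mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(self.seed)
        )

    def state_dict(self):
        return {"seed": self.seed}

    def load_state_dict(self, state_dict):
        self.seed = state_dict["seed"]

    # train_dataloader / val_dataloader / test_dataloader omitted.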