If you want to keep the same split created by

```python
self.mnist_train, self.mnist_val = random_split(  # This is what I want to maintain.
    mnist_full, [55000, 5000], generator=torch.Generator().manual_seed(42)
)
```

then hardcoding the seed like this is already enough: with a fixed seed, the call will always produce exactly the same split.
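To illustrate, here is a small sketch (using a toy 100-sample dataset instead of MNIST, an assumption for brevity) showing that `random_split` with a seeded generator returns identical indices on every call:

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Toy dataset standing in for mnist_full (assumption: 100 samples, not 60000)
data = TensorDataset(torch.arange(100))

def make_split():
    # Same seed -> same permutation -> identical train/val split every run
    return random_split(data, [90, 10], generator=torch.Generator().manual_seed(42))

train_a, val_a = make_split()
train_b, val_b = make_split()

# The two splits select exactly the same sample indices
assert train_a.indices == train_b.indices
assert val_a.indices == val_b.indices
```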
In general, if you want to save the state of the data module, implement the two methods `state_dict()` and `load_state_dict()`:

https://lightning.ai/docs/pytorch/stable/data/datamodule.html#state-dict
https://lightning.ai/docs/pytorch/stable/data/datamodule.html#load-state-dict
For example:

```python
def state_dict(self):
    return {"seed": self.seed}

def load_state_dict(self, state_dict):
    self.seed = state_dict["seed"]
```
If you use the Trainer, these methods are called automatically to save the data module's state to the checkpoint and restore it from there.
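Putting the two halves together, the stored seed can drive the split itself, so restoring the state reproduces the exact same partition. A minimal sketch, assuming a plain class standing in for a `LightningDataModule` (so it runs without Lightning installed; in practice you would subclass `lightning.pytorch.LightningDataModule`) and a toy 100-sample dataset:

```python
import torch
from torch.utils.data import TensorDataset, random_split

class MNISTDataModule:  # assumption: stand-in for a LightningDataModule subclass
    def __init__(self, seed: int = 42):
        self.seed = seed

    def setup(self, stage=None):
        full = TensorDataset(torch.arange(100))  # toy stand-in for mnist_full
        # The stored seed drives the split, so a restored seed
        # reproduces the exact same train/val partition.
        self.train, self.val = random_split(
            full, [90, 10], generator=torch.Generator().manual_seed(self.seed)
        )

    def state_dict(self):
        # Called by the Trainer when writing a checkpoint
        return {"seed": self.seed}

    def load_state_dict(self, state_dict):
        # Called by the Trainer when restoring from a checkpoint
        self.seed = state_dict["seed"]

dm = MNISTDataModule(seed=7)
dm.setup()

restored = MNISTDataModule()               # default seed...
restored.load_state_dict(dm.state_dict())  # ...overwritten from the "checkpoint"
restored.setup()

assert restored.train.indices == dm.train.indices  # identical split after restore
```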