There are a number of issues I’ve encountered when trying to ensure deterministic, reproducible results when resuming from a checkpoint. For example, given a training session that runs for, say, 10 epochs, if I re-run it from one of the intermediate epochs the results (train loss) will differ. This is despite setting everything as advised in the reproducibility and determinism section of the Lightning Trainer docs. I went through the code to track down the causes, and here are my findings:
Seeding the random generators (torch.random) with the same seed, say seed = x, when resuming from checkpoint is not enough. The states of the random generators (np.random.get_state() etc.) change as they are used. So if you start an initial training run, seed everything with x and then call e.g. np.random.random() (or any other random generator for that matter), its state will change. Then, if you resume training from a checkpoint at epoch n and again seed everything with x at the beginning, all generators will have THE SAME state as they had at the beginning of the initial training, hence a DIFFERENT one than at epoch n from which the training was resumed. My solution to this was to modify the checkpoint dictionary to also include the states of the random generators:
_checkpoint["random_state"] = random.getstate()
_checkpoint["np_random_state"] = np.random.get_state()
_checkpoint["torch_random_state"] = torch.random.get_rng_state()
Then, while restoring the state in self.restore_loops (sorry, can’t post more links due to being a new user…) I added the relevant code to load those states, and it solved it, mostly… random.random seemed not to work as intended (note that the values generated and the generator states when the checkpoint was loaded were exactly the same as the states and values recorded when saving them during the initial training); my assumption is that somewhere in the codebase those generators must be called a different number of times while advancing to the next epoch vs. while resuming from a checkpoint. Fortunately, numpy seems to work just fine (I guess it’s never used internally in Lightning/PyTorch), so my suggestion is to stick to np.random.random() to ensure that you ALWAYS get the same values when the state is loaded.
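As a sketch of the save/restore idea above (the helper names are mine; instead of patching Lightning internals, you could call these from the LightningModule.on_save_checkpoint / on_load_checkpoint hooks, which receive the checkpoint dict):

```python
import random

import numpy as np
import torch


def capture_rng_states():
    # Collect the current state of every RNG the training loop may touch,
    # so they can be stored inside the checkpoint dictionary.
    return {
        "random_state": random.getstate(),
        "np_random_state": np.random.get_state(),
        "torch_random_state": torch.random.get_rng_state(),
    }


def restore_rng_states(states):
    # Put each generator back exactly where it was at checkpoint time,
    # instead of merely re-seeding from scratch.
    random.setstate(states["random_state"])
    np.random.set_state(states["np_random_state"])
    torch.random.set_rng_state(states["torch_random_state"])
```

With this, every value drawn after restoring matches what would have been drawn had training continued uninterrupted.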
Next, at this point, if you are not shuffling your data and keep num_workers = 0, you should be good to go. Setting shuffle=True in the data loader is another problem. First, shuffle=True makes the loader use a RandomSampler. All is good until you get into its __iter__, which calls seed = int(torch.empty((), dtype=torch.int64).random_().item()) when generator=None. Now, as I mentioned in the first step, restoring the state of torch.random didn’t help here; as I said, I assume its state must change differently when advancing from one epoch to another vs. when starting from a checkpoint, so that seed will not be the same, and neither will your batches. My initial solution, of course, was to supply a generator object, and I thought that would solve it, which it didn’t. The reason is relatively simple but annoying to find, and again the problem is that if your loader is using a RandomSampler and you give it a generator, then its state by the time it gets to __iter__ will be different when advancing to the next epoch vs. when you start from a checkpoint. This is due to the fact that when starting training Lightning will call FitLoop, which calls _check_dataloader_iterable, and this in turn calls iter(dataloader). This happens once again, I think, somewhere in advance(), so essentially when you advance to the next epoch iter(dataloader) is called once, but when you start from a checkpoint iter(dataloader) is called at least twice as it needs to set up the data. Therefore, the state of the generator will be different by the time you get to RandomSampler.__iter__. Fixing this was relatively simple, as you only need to override RandomSampler (don’t pass a generator) and set a seed yourself on the line where PyTorch calls seed = int(torch.empty((), dtype=torch.int64).random_().item()).
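A minimal sketch of such an override (the class name and the set_epoch mechanism are my own additions; the key point, as above, is that the shuffle order depends only on an explicit seed, not on how many times iter(dataloader) has been called):

```python
import torch
from torch.utils.data import RandomSampler


class SeededRandomSampler(RandomSampler):
    """RandomSampler whose shuffle order is a pure function of (seed, epoch),
    so extra iter(dataloader) calls during checkpoint restore cannot drift it."""

    def __init__(self, data_source, seed=42):
        super().__init__(data_source)
        self.seed = seed
        self.epoch = 0  # bump per epoch so each epoch still gets a new order

    def set_epoch(self, epoch):
        self.epoch = epoch

    def __iter__(self):
        # Build a fresh generator on every call: no state leaks between
        # consecutive iter() invocations, unlike a shared `generator=` object.
        g = torch.Generator()
        g.manual_seed(self.seed + self.epoch)
        yield from torch.randperm(len(self.data_source), generator=g).tolist()
```

Pass an instance of this as sampler= to the DataLoader (and leave shuffle unset), calling set_epoch at the start of each epoch.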
Now, the final issue is when you set num_workers > 0. The seed_everything function, which is meant to deal with workers, is not doing enough. You also need to define your own worker_init_fn and pass it to the data loader as you initialize it. The function needs to re-seed a number of things; the fix can be found here: How to fix all workers' seed via worker_init_fn for every iter? - #2 by ptrblck - vision - PyTorch Forums.
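A worker_init_fn along the lines of the linked fix might look like this (the function name is mine; it derives each worker's seed from torch's per-worker initial seed, so restoring the main-process RNG state makes the worker streams reproducible too):

```python
import random

import numpy as np
import torch


def seeded_worker_init_fn(worker_id):
    # torch.initial_seed() inside a worker already reflects the base seed
    # plus the worker id; fold it down to re-seed the other libraries.
    worker_seed = torch.initial_seed() % 2**32
    random.seed(worker_seed)
    np.random.seed(worker_seed)


# Usage (sketch):
# loader = DataLoader(dataset, num_workers=4,
#                     worker_init_fn=seeded_worker_init_fn)
```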
Doing all of this, as well as following the guidelines on the Lightning page, will ensure that if you start from a checkpoint your training run will produce THE SAME results. This needs fixing, especially point 3, which was already discussed on the PyTorch forum.
Hope this helps anyone who’s struggling with these issues, and I’m looking forward to Lightning fixing this.