Resuming from checkpoint gives different results

There are a number of issues I’ve encountered when trying to ensure deterministic, reproducible results when resuming from a checkpoint. For example, given a training session that runs for, say, 10 epochs, if I re-run it from one of the intermediate epochs the results (train loss) will differ from the original run. This is despite setting everything as advised on the Lightning Trainer side in the reproducibility and determinism section. I went through the code to track down the causes, and here are my findings:
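For reference, this is the kind of baseline setup I mean (a minimal sketch; the seed value and max_epochs are arbitrary, and the exact flags may differ between Lightning versions):

```python
import pytorch_lightning as pl

# seeds python's random, numpy and torch; workers=True also covers dataloader workers
pl.seed_everything(42, workers=True)

# deterministic=True asks PyTorch to use deterministic algorithms where possible
trainer = pl.Trainer(max_epochs=10, deterministic=True)

# resuming would then look like (model / train_loader / path are placeholders):
# trainer.fit(model, train_dataloaders=train_loader, ckpt_path="path/to/epoch=n.ckpt")
```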

  1. Seeding the random generators (random, np.random, torch.random) with the same seed x when resuming from a checkpoint is not enough. The states of the random generators (np.random.get_state() etc.) change as they are used, so if you start an initial training run, seed everything with x and then call e.g. np.random.random() (or any other random generator, for that matter), its state will change. If you then resume training from a checkpoint at epoch n and again seed everything with x at the beginning, all generators will have THE SAME state as they had at the start of the initial training, hence a DIFFERENT state than they had at epoch n, from which training was resumed. My solution was to modify this dictionary to also include the states of the random generators:
    _checkpoint["random_state"] = random.getstate()
    _checkpoint["np_random_state"] = np.random.get_state()
    _checkpoint["torch_random_state"] = torch.random.get_rng_state()
    Then, while restoring the state in checkpoint_connector L292, in restore_training_state after self.restore_loops (sorry, I can’t post more links due to being a new user…), I added the relevant code to load those states, and it solved the problem, mostly… torch.random and random still seemed not to work as intended (note that the values generated and the generator states when the checkpoint was loaded were exactly the same as the states and values captured when saving them during the initial training); my assumption is that somewhere in the codebase those generators are called a different number of times when advancing to the next epoch than when resuming from a checkpoint. Fortunately, numpy works just fine (I guess it’s simply never used internally by Lightning/PyTorch), so my suggestion is to stick to np.random.random() to ensure that you ALWAYS get the same values once the state is loaded. A callback-based sketch of the same idea follows below.
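    For anyone who prefers not to patch checkpoint_connector, here is a minimal sketch of the same idea as a Callback (the class name and checkpoint keys are my own; note that on_load_checkpoint may fire at a slightly different point than my patch after self.restore_loops, so your mileage may vary):

    ```python
    import random

    import numpy as np
    import torch
    from pytorch_lightning.callbacks import Callback


    class RNGStateCheckpoint(Callback):
        """Save and restore the global RNG states alongside the checkpoint."""

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            checkpoint["random_state"] = random.getstate()
            checkpoint["np_random_state"] = np.random.get_state()
            checkpoint["torch_random_state"] = torch.random.get_rng_state()

        def on_load_checkpoint(self, trainer, pl_module, checkpoint):
            if "np_random_state" in checkpoint:
                random.setstate(checkpoint["random_state"])
                np.random.set_state(checkpoint["np_random_state"])
                torch.random.set_rng_state(checkpoint["torch_random_state"])
    ```

    You would then pass it via Trainer(callbacks=[RNGStateCheckpoint()], ...).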

  2. Next, at this point, if you are not shuffling your data and not setting num_workers > 0, you should be good to go. Setting shuffle=True in the data loader is another problem. First, if you set it to True, the DataLoader will use a RandomSampler. All is well until you get into __iter__, which calls seed = int(torch.empty((), dtype=torch.int64).random_().item()) when generator=None. Now, as I mentioned in the first point, restoring the state of torch.random didn’t help here; again, I assume its state changes differently when advancing from one epoch to another than when starting from a checkpoint, so that seed will not be the same and neither will your batches. My initial solution was of course to supply a generator object, which I thought would solve it, but it didn’t. The reason is relatively simple but annoying to find, and again the problem is that if your loader uses a RandomSampler and you give it a generator, then the generator’s state by the time it gets to __iter__ will be different when advancing to the next epoch than when starting from a checkpoint. This is due to the fact that when starting training Lightning calls setup_data in FitLoop, which calls _check_dataloader_iterable, and this in turn calls iter(dataloader). This happens once again, I think somewhere in advance(), so essentially when you advance to the next epoch iter(dataloader) is called once, but when you start from a checkpoint iter(dataloader) is called at least twice, since the data needs to be set up first. Therefore, the state of the generator will be different by the time you get to RandomSampler.__iter__. Fixing this was relatively simple, as you only need to override RandomSampler (don’t pass a generator) and set a seed yourself on the line where PyTorch calls seed = int(torch.empty((), dtype=torch.int64).random_().item()); see the sketch below.
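    A minimal sketch of that override (the class name and the base_seed + epoch scheme are my own choices; the __iter__ body mirrors the no-replacement branch of PyTorch’s RandomSampler):

    ```python
    import torch
    from torch.utils.data import DataLoader, RandomSampler


    class SeededRandomSampler(RandomSampler):
        """RandomSampler that derives its seed from a fixed base seed plus the epoch,
        instead of from the global torch RNG, so the order survives resuming."""

        def __init__(self, data_source, base_seed=42):
            super().__init__(data_source)  # deliberately no generator argument
            self.base_seed = base_seed
            self.epoch = 0

        def set_epoch(self, epoch):
            # call this at the start of every epoch (e.g. from on_train_epoch_start)
            self.epoch = epoch

        def __iter__(self):
            # replaces seed = int(torch.empty((), dtype=torch.int64).random_().item())
            generator = torch.Generator()
            generator.manual_seed(self.base_seed + self.epoch)
            yield from torch.randperm(len(self.data_source), generator=generator).tolist()


    # usage: pass the sampler instead of shuffle=True
    # loader = DataLoader(dataset, batch_size=32, sampler=SeededRandomSampler(dataset))
    ```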

  3. Now, the final issue arises when you set num_workers > 0. seed_everything, which is meant to deal with workers, is not doing enough. You also need to define your own worker_init_fn and pass it to the data loader as you initialize it. The function needs to re-seed a number of things; the fix can be found here: How to fix all workers' seed via worker_init_fn for every iter? - #2 by ptrblck - vision - PyTorch Forums. A sketch is included below.
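    A minimal sketch of such a worker_init_fn, following the recipe from that post (the function name and the toy dataset are placeholders):

    ```python
    import random

    import numpy as np
    import torch
    from torch.utils.data import DataLoader, TensorDataset


    def seed_worker(worker_id):
        # torch.initial_seed() is already unique per worker and per epoch;
        # propagate it to the other RNGs the workers might use
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)


    dataset = TensorDataset(torch.arange(100).float())  # toy dataset as a placeholder
    loader = DataLoader(
        dataset,
        batch_size=8,
        num_workers=4,
        worker_init_fn=seed_worker,
    )
    ```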

Doing all of this, as well as following the guidelines on the Lightning page, will ensure that when you start from a checkpoint your training run will produce THE SAME results. This needs fixing, especially point 3, which has already been discussed on the PyTorch forum.

Hope this helps anyone who’s struggling with these issues, and I’m looking forward to Lightning fixing this.