There are a number of issues I've encountered when trying to ensure deterministic, reproducible results when resuming from a checkpoint. For example, given a training session that runs for, say, 10 epochs, if I re-run it from one of the epochs, the results (train loss) will differ. This is despite setting everything as advised on the Lightning Trainer side in the reproducibility and determinism section. I went through the code to track down the causes, and here are my findings:
- Seeding the random generators (`random`, `np.random`, `torch.random`) with the same seed (`seed = x`) when resuming from a checkpoint is not enough. The states of the random generators (`np.random.get_state()` etc.) change as they are used. So if you start an initial training run, seed everything with `x`, and then call e.g. `np.random.random()` (or any other random generator for that matter), its state will change. If you then resume training from a checkpoint at epoch `n` and again seed everything with `x` at the beginning, all generators will have THE SAME state they had at the beginning of the initial training, hence a DIFFERENT state than at epoch `n`, from which the training was resumed. My solution was to modify this dictionary to include the states of the random generators:
```python
_checkpoint["random_state"] = random.getstate()
_checkpoint["np_random_state"] = np.random.get_state()
_checkpoint["torch_random_state"] = torch.random.get_rng_state()
```
Then, whilst restoring the state in `checkpoint_connector` (line 292, in `restore_training_state`, after `self.restore_loops` — sorry, I can't post more links due to being a new user…) I added the corresponding code to load those states, and it solved the problem, mostly… `torch.random` and `random.random` seemed to not work as intended (note that the values generated, and their states when the checkpoint was loaded, were exactly the same as the states and values generated when saving them during the initial training); my assumption is that somewhere in the codebase those generators must be called a different number of times when advancing to the next epoch vs. when resuming from a checkpoint. Fortunately, numpy seems to work just fine (I guess it's never used internally in Lightning/PyTorch), so my suggestion is to stick to `np.random.random()` to ensure that you ALWAYS get the same values when the state is loaded.
-
Next, at this point, if you are not shuffling your data and not setting `num_workers > 0`, you should be good to go. Setting `shuffle=True` in the data loader is another problem. First, if you set it to `True`, then `DataLoader` will use a `RandomSampler`. All good until you get into `__iter__`, which, if `generator=None`, calls `seed = int(torch.empty((), dtype=torch.int64).random_().item())`. Now, as I mentioned in the first point, restoring the state of `torch.random` didn't help here; as I said, I assume its state changes differently when advancing from one epoch to another vs. when starting from a checkpoint, so that number will not be the same, and neither will your batches. My initial solution, of course, was to supply a `generator` object, which I thought would solve it — it didn't. The reason is relatively simple but annoying to find: again, if your loader uses a `RandomSampler` and you give it a `generator`, then the generator's state by the time it gets to `__iter__` will be different when advancing to the next epoch vs. when starting from a checkpoint. This is because, when starting training, Lightning calls `setup_data` in `FitLoop`, which calls `_check_dataloader_iterable`, and this in turn calls `iter(dataloader)`. This happens once again, I think, somewhere in `advance()`, so essentially, when you advance to the next epoch `iter(dataloader)` is called once, but when you start from a checkpoint `iter(dataloader)` is called at least twice, as the data needs to be set up. Therefore, the state of the `generator` will be different by the time you get to `RandomSampler.__iter__`. Fixing this was relatively simple: you only need to override `RandomSampler` (don't pass a generator) and set a seed yourself to some fixed value on the line where PyTorch calls `seed = int(torch.empty((), dtype=torch.int64).random_().item())`.
-
Now, the final issue appears when you set `num_workers > 0`. `seed_everything`, which is meant to deal with workers, is not doing enough. You also need to define your own `worker_init_fn` and pass it to the data loader as you initialize it. The function needs to re-seed a number of things; the fix can be found here: How to fix all workers' seed via worker_init_fn for every iter? - #2 by ptrblck - vision - PyTorch Forums.
Doing all of the above, as well as following the guidelines on the Lightning page, will ensure that if you start from a checkpoint your training run will produce THE SAME results. This — especially point 3, which was already discussed on the PyTorch forum — needs fixing.

Hope this helps anyone who's struggling with these issues, and I'm looking forward to Lightning fixing this.