# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to help with reproducibility of models."""

import logging
import os
import random
from contextlib import contextmanager
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Dict, Generator, Optional

import numpy as np
import torch

from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

log = logging.getLogger(__name__)
def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Function that sets the seed for pseudo-random number generators in: pytorch, numpy, python.random.
    In addition, sets the following environment variables:

    - `PL_GLOBAL_SEED`: will be passed to spawned subprocesses (e.g. ddp_spawn backend).
    - `PL_SEED_WORKERS`: (optional) is set to 1 if ``workers=True``.

    Args:
        seed: the integer value seed for global random state in Lightning.
            If `None`, will read seed from `PL_GLOBAL_SEED` env variable
            or select it randomly.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument will have no influence. See also:
            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.
    """
    max_seed_value = np.iinfo(np.uint32).max
    min_seed_value = np.iinfo(np.uint32).min

    if seed is None:
        env_seed = os.environ.get("PL_GLOBAL_SEED")
        if env_seed is None:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")
        else:
            try:
                seed = int(env_seed)
            except ValueError:
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_seed)}, seed set to {seed}")
    elif not isinstance(seed, int):
        seed = int(seed)

    if not (min_seed_value <= seed <= max_seed_value):
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # using `log.info` instead of `rank_zero_info`,
    # so users can verify the seed is properly set in distributed training.
    log.info(f"Global seed set to {seed}")
    os.environ["PL_GLOBAL_SEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    os.environ["PL_SEED_WORKERS"] = f"{int(workers)}"

    return seed
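# NOTE: `_select_seed_randomly` is called above but its definition was elided from this
# excerpt. Below is a minimal reconstruction so the module runs standalone; the exact
# default bounds are an assumption, and the callers above always pass both explicitly.
def _select_seed_randomly(min_seed_value: int = 0, max_seed_value: int = 255) -> int:
    """Pick a random seed within the given bounds."""
    return random.randint(min_seed_value, max_seed_value)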
def reset_seed() -> None:
    """Reset the seed to the value that :func:`pytorch_lightning.utilities.seed.seed_everything` previously set.

    If :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function will do nothing.
    """
    seed = os.environ.get("PL_GLOBAL_SEED", None)
    if seed is None:
        return
    workers = os.environ.get("PL_SEED_WORKERS", "0")
    seed_everything(int(seed), workers=bool(int(workers)))
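# Usage sketch (illustrative, not part of the original module): `reset_seed` re-reads
# `PL_GLOBAL_SEED` / `PL_SEED_WORKERS` from the environment and re-applies them:
#
#     seed_everything(42, workers=True)
#     ...                # code that advances the global RNGs
#     reset_seed()       # global seed is set back to 42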
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """The ``worker_init_fn`` that Lightning automatically adds to your dataloader if you previously set the seed
    with ``seed_everything(seed, workers=True)``.

    See also the PyTorch documentation on
    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.
    """
    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    global_rank = rank if rank is not None else rank_zero_only.rank
    process_seed = torch.initial_seed()
    # back out the base seed so we can use all the bits
    base_seed = process_seed - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )
    ss = np.random.SeedSequence([base_seed, worker_id, global_rank])
    # use 128 bits (4 x 32-bit words)
    np.random.seed(ss.generate_state(4))
    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_ss, stdlib_ss = ss.spawn(2)
    torch.manual_seed(torch_ss.generate_state(1, dtype=np.uint64)[0])
    # use 128 bits expressed as an integer
    stdlib_seed = (stdlib_ss.generate_state(2, dtype=np.uint64).astype(object) * [1 << 64, 1]).sum()
    random.seed(stdlib_seed)
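# Usage sketch (illustrative, not part of the original module): Lightning wires this
# function up automatically when `seed_everything(..., workers=True)` was called, but it
# can also be attached to a plain PyTorch DataLoader by hand (`dataset` is hypothetical):
#
#     from torch.utils.data import DataLoader
#     loader = DataLoader(dataset, num_workers=4, worker_init_fn=pl_worker_init_function)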
def _collect_rng_states() -> Dict[str, Any]:
    """Collect the global random state of :mod:`torch`, :mod:`numpy` and Python."""
    return {
        "torch": torch.get_rng_state(),
        "numpy": np.random.get_state(),
        "python": python_get_rng_state(),
    }


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Set the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process."""
    torch.set_rng_state(rng_state_dict["torch"])
    np.random.set_state(rng_state_dict["numpy"])
    version, state, gauss = rng_state_dict["python"]
    python_set_rng_state((version, tuple(state), gauss))
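# Round-trip sketch (illustrative): the two helpers above are inverses of each other,
# which is exactly what `isolate_rng` below relies on:
#
#     states = _collect_rng_states()   # snapshot torch, numpy and stdlib RNG states
#     torch.rand(1)                    # ... any code that advances the RNGs ...
#     _set_rng_states(states)          # restore the snapshot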
@contextmanager
def isolate_rng() -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.

    Example:
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states()
    yield
    _set_rng_states(states)
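# Usage sketch (illustrative, not part of the original module): `isolate_rng` is handy
# around code that must not perturb the global training RNG state, e.g. drawing the same
# diagnostic sample every epoch:
#
#     with isolate_rng():
#         torch.manual_seed(0)
#         fixed_batch = torch.randn(8, 3, 32, 32)   # identical on every call
#     # the pre-existing global RNG state is restored here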