Shortcuts

Source code for pytorch_lightning.utilities.seed

# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities to help with reproducibility of models."""

import logging
import os
import random
from contextlib import contextmanager
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Dict, Generator, Optional

import numpy as np
import torch

from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Inclusive seed bounds, taken from the value range of an unsigned 32-bit integer.
# `seed_everything` replaces any seed outside this range with a randomly selected one.
min_seed_value = np.iinfo(np.uint32).min
max_seed_value = np.iinfo(np.uint32).max


def seed_everything(seed: Optional[int] = None, workers: bool = False) -> int:
    """Seed the pseudo-random number generators of pytorch, numpy and python.random.

    Besides seeding, the following environment variables are written:

    - `PL_GLOBAL_SEED`: the seed that was finally used, so that spawned subprocesses
      (e.g. ddp_spawn backend) can pick it up.
    - `PL_SEED_WORKERS`: ``"1"`` if ``workers=True``, otherwise ``"0"``.

    Args:
        seed: the integer value seed for global random state in Lightning.
            If `None`, the seed is read from the `PL_GLOBAL_SEED` env variable; if that variable is unset
            or cannot be parsed as an integer, a seed is selected randomly and a warning is emitted.
            A seed outside ``[min_seed_value, max_seed_value]`` is likewise replaced by a random one.
        workers: if set to ``True``, will properly configure all dataloaders passed to the
            Trainer with a ``worker_init_fn``. If the user already provides such a function
            for their dataloaders, setting this argument will have no influence. See also:
            :func:`~pytorch_lightning.utilities.seed.pl_worker_init_function`.

    Returns:
        The seed that was actually applied.
    """
    if seed is not None:
        # An explicit seed was given: coerce anything that is not already an int.
        if not isinstance(seed, int):
            seed = int(seed)
    else:
        # No explicit seed: fall back to the environment, then to a random choice.
        env_value = os.environ.get("PL_GLOBAL_SEED")
        if env_value is not None:
            try:
                seed = int(env_value)
            except ValueError:
                seed = _select_seed_randomly(min_seed_value, max_seed_value)
                rank_zero_warn(f"Invalid seed found: {repr(env_value)}, seed set to {seed}")
        else:
            seed = _select_seed_randomly(min_seed_value, max_seed_value)
            rank_zero_warn(f"No seed found, seed set to {seed}")

    within_bounds = min_seed_value <= seed <= max_seed_value
    if not within_bounds:
        rank_zero_warn(f"{seed} is not in bounds, numpy accepts from {min_seed_value} to {max_seed_value}")
        seed = _select_seed_randomly(min_seed_value, max_seed_value)

    # using `log.info` instead of `rank_zero_info`,
    # so users can verify the seed is properly set in distributed training.
    log.info(f"Global seed set to {seed}")
    os.environ["PL_GLOBAL_SEED"] = str(seed)

    # Seed every generator with the same value.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    os.environ["PL_SEED_WORKERS"] = str(int(workers))
    return seed
def _select_seed_randomly(min_seed_value: int = min_seed_value, max_seed_value: int = max_seed_value) -> int:
    """Pick a random integer seed from ``[min_seed_value, max_seed_value]``, both ends inclusive.

    The value is drawn from Python's global :mod:`random` generator.
    """
    lower, upper = min_seed_value, max_seed_value
    return random.randint(lower, upper)
def reset_seed() -> None:
    """Re-apply the seed that :func:`pytorch_lightning.utilities.seed.seed_everything` stored earlier.

    The seed is read back from the `PL_GLOBAL_SEED` environment variable and the worker flag from
    `PL_SEED_WORKERS` (treated as ``"0"`` when unset). If `PL_GLOBAL_SEED` is not set, i.e.
    :func:`pytorch_lightning.utilities.seed.seed_everything` is unused, this function does nothing.
    """
    stored_seed = os.environ.get("PL_GLOBAL_SEED", None)
    if stored_seed is not None:
        workers_flag = os.environ.get("PL_SEED_WORKERS", "0")
        seed_everything(int(stored_seed), workers=bool(int(workers_flag)))
def pl_worker_init_function(worker_id: int, rank: Optional[int] = None) -> None:  # pragma: no cover
    """The ``worker_init_fn`` that Lightning automatically adds to your dataloader if you previously set the seed
    with ``seed_everything(seed, workers=True)``.

    It re-seeds numpy, PyTorch and Python's :mod:`random` inside a dataloader worker from a seed sequence built
    from the base seed, the worker id and the global rank.

    See also the PyTorch documentation on
    `randomness in DataLoaders <https://pytorch.org/docs/stable/notes/randomness.html#dataloader>`_.

    Args:
        worker_id: the id of the dataloader worker being initialized.
        rank: the global rank of the process; when ``None``, ``rank_zero_only.rank`` is used.
    """
    # implementation notes: https://github.com/pytorch/pytorch/issues/5059#issuecomment-817392562
    if rank is None:
        global_rank = rank_zero_only.rank
    else:
        global_rank = rank

    # back out the base seed so we can use all the bits
    base_seed = torch.initial_seed() - worker_id
    log.debug(
        f"Initializing random number generators of process {global_rank} worker {worker_id} with base seed {base_seed}"
    )

    root_sequence = np.random.SeedSequence([base_seed, worker_id, global_rank])

    # numpy: use 128 bits (4 x 32-bit words)
    numpy_seed_words = root_sequence.generate_state(4)
    np.random.seed(numpy_seed_words)

    # Spawn distinct SeedSequences for the PyTorch PRNG and the stdlib random module
    torch_sequence, python_sequence = root_sequence.spawn(2)

    # PyTorch: a single 64-bit word
    torch_seed = torch_sequence.generate_state(1, dtype=np.uint64)[0]
    torch.manual_seed(torch_seed)

    # stdlib random: use 128 bits expressed as an integer (high word shifted by 64 bits, plus low word);
    # the words are cast to Python objects before the multiplication
    python_words = python_sequence.generate_state(2, dtype=np.uint64).astype(object)
    python_seed = (python_words * [1 << 64, 1]).sum()
    random.seed(python_seed)
def _collect_rng_states() -> Dict[str, Any]:
    """Snapshot the global random state of :mod:`torch`, :mod:`numpy` and Python.

    Returns:
        A dict with the keys ``"torch"``, ``"numpy"`` and ``"python"``, each holding the state object
        returned by the respective library.
    """
    snapshot: Dict[str, Any] = {}
    snapshot["torch"] = torch.get_rng_state()
    snapshot["numpy"] = np.random.get_state()
    snapshot["python"] = python_get_rng_state()
    return snapshot


def _set_rng_states(rng_state_dict: Dict[str, Any]) -> None:
    """Restore the global random state of :mod:`torch`, :mod:`numpy` and Python in the current process.

    Args:
        rng_state_dict: a dict with the keys ``"torch"``, ``"numpy"`` and ``"python"``, as produced by
            :func:`_collect_rng_states`.
    """
    torch.set_rng_state(rng_state_dict["torch"])
    np.random.set_state(rng_state_dict["numpy"])

    python_state = rng_state_dict["python"]
    version, internal_state, gauss_next = python_state
    # The middle element is coerced to a tuple before being handed to `random.setstate`
    # (presumably so that a state stored as a list is accepted as well).
    python_set_rng_state((version, tuple(internal_state), gauss_next))
@contextmanager
def isolate_rng() -> Generator[None, None, None]:
    """A context manager that resets the global random state on exit to what it was before entering.

    It supports isolating the states for PyTorch, Numpy, and Python built-in random number generators.
    The saved states are restored on every exit path, including when the body of the ``with`` block
    raises an exception (the exception still propagates to the caller).

    Example:
        >>> torch.manual_seed(1)  # doctest: +ELLIPSIS
        <torch._C.Generator object at ...>
        >>> with isolate_rng():
        ...     [torch.rand(1) for _ in range(3)]
        [tensor([0.7576]), tensor([0.2793]), tensor([0.4031])]
        >>> torch.rand(1)
        tensor([0.7576])
    """
    states = _collect_rng_states()
    try:
        yield
    finally:
        # Without the `finally`, an exception raised inside the `with` body would be re-raised at the
        # `yield` and the restore below would be skipped, leaving the global RNG state modified.
        _set_rng_states(states)

© Copyright (c) 2018-2023, Lightning AI et al.

Built with Sphinx using a theme provided by Read the Docs.