# Source code for baybe.utils.random
"""Utilities targeting random number generation."""
from __future__ import annotations
import contextlib
import random
import warnings
from typing import TYPE_CHECKING
import numpy as np
from attrs import cmp_using, define, field
from typing_extensions import Self, deprecated
if TYPE_CHECKING:
from torch import Tensor
# [docs]
@deprecated(
    "Using 'set_random_seed' is deprecated and support will be removed in a future "
    "release. Use the new settings management system instead. For details: "
    "https://emdgroup.github.io/baybe/stable/userguide/settings.html",
)
def set_random_seed(seed: int):
    """Seed the global random number generators of Torch, Python and NumPy.

    Args:
        seed: The chosen global random seed. It is reduced modulo ``2**32``
            before being passed on to the individual libraries.
    """
    import torch

    # Keep the seed within limits by wrapping it into the 32-bit range
    bounded_seed = seed % 2**32

    # Same seed for every library, applied in a fixed order
    for apply_seed in (torch.manual_seed, random.seed, np.random.seed):
        apply_seed(bounded_seed)
# [docs]
@deprecated(
    "Using 'temporary_seed' is deprecated and support will be removed in a future "
    "release. Use the new settings management system instead. For details: "
    "https://emdgroup.github.io/baybe/stable/userguide/settings.html",
)
@contextlib.contextmanager
def temporary_seed(seed: int):  # noqa: DOC402, DOC404
    """Seed all random number generators for the duration of a ``with`` block.

    On entry, the current states of the Python, NumPy and Torch generators are
    recorded and the generators are seeded through ``set_random_seed``. On exit,
    including when the managed code raises, the recorded states are put back.

    Args:
        seed: The chosen random seed. It is reduced modulo ``2**32`` before use.
    """
    import torch

    # Keep the seed within limits by wrapping it into the 32-bit range
    seed = seed % 2**32

    # Pair each restore function with a snapshot of the state it will restore
    saved_states = (
        (random.setstate, random.getstate()),
        (np.random.set_state, np.random.get_state()),
        (torch.set_rng_state, torch.get_rng_state()),
    )

    # Seed through the (also deprecated) helper, muting only its own warning
    with warnings.catch_warnings():
        warnings.filterwarnings(
            "ignore",
            message="Using 'set_random_seed' is deprecated",
            category=DeprecationWarning,
        )
        set_random_seed(seed)

    try:
        # Hand control to the body of the ``with`` statement
        yield
    finally:
        # Put the snapshots back in the order they were taken
        for restore, snapshot in saved_states:
            restore(snapshot)
def _lazy_torch_equal(a: Tensor, b: Tensor, /) -> bool:
    """Compare two tensors via ``torch.equal``, importing torch only when called."""
    # Function-local import keeps torch from being loaded at module import time
    from torch import equal as tensors_equal

    return tensors_equal(a, b)
@define(frozen=True)
class _RandomState:
    """Container for the random states of all managed numeric libraries.

    The class takes no constructor arguments: creating an instance snapshots the
    current global states of Python's ``random`` module, NumPy
    (``np.random.get_state``) and Torch (``torch.get_rng_state``). A snapshot can
    later be re-installed with :meth:`activate`.
    """

    state_python = field(init=False, factory=random.getstate)
    """The state of the Python random number generator."""

    state_numpy = field(
        init=False,
        factory=np.random.get_state,
        # The state is compared entry by entry with ``np.array_equal``
        # (presumably because array-valued entries do not reduce to a single
        # boolean under a plain ``==``).
        eq=cmp_using(
            eq=lambda s1, s2: all(np.array_equal(a, b) for a, b in zip(s1, s2))
        ),
    )
    """The state of the Numpy random number generator."""

    state_torch: Tensor = field(init=False, eq=cmp_using(eq=_lazy_torch_equal))
    """The state of the Torch random number generator."""
    # Note: initialized by attrs default method below (for lazy torch loading)

    @state_torch.default
    def _default_state_torch(self) -> Tensor:
        """Get the current Torch random state using a lazy import."""
        import torch

        return torch.get_rng_state()

    def activate(self) -> None:
        """Activate the random state.

        Installs the stored Python, NumPy and Torch states as the current global
        states of the respective libraries.
        """
        import torch

        random.setstate(self.state_python)
        np.random.set_state(self.state_numpy)
        torch.set_rng_state(self.state_torch)

    @staticmethod
    def _reseed(seed: int) -> None:
        """Seed all random number generators.

        Args:
            seed: The seed passed unchanged to ``random.seed``,
                ``np.random.seed`` and ``torch.manual_seed`` (in that order).
        """
        import torch

        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)

    @classmethod
    def from_seed(cls, seed: int, *, activate: bool = False) -> Self:
        """Create a random state corresponding to a given seed.

        Args:
            seed: The seed from which the random state is derived.
            activate: If ``True``, the generators remain seeded after the call.
                If ``False``, the states that were active before the call are
                restored once the seeded state has been captured.

        Returns:
            The random state obtained by seeding all generators with ``seed``.

        Raises:
            Exception: Whatever the underlying seeding functions raise for the
                given seed. In that case, the states that were active before the
                call are restored first, regardless of ``activate``.
        """
        # Snapshot first so that we can roll back, both for the non-activating
        # path and when seeding fails halfway (the generators are seeded one
        # after another, so a late failure would otherwise leave the earlier
        # ones already modified).
        snapshot = cls()
        try:
            cls._reseed(seed)
            state = cls()
        except BaseException:
            snapshot.activate()
            raise

        if not activate:
            snapshot.activate()
        return state