"""BayBE settings."""
from __future__ import annotations
import gc
import os
import tempfile
import warnings
from copy import deepcopy
from functools import wraps
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, TypeVar
import numpy as np
from attrs import Attribute, Converter, Factory, define, field, fields
from attrs.setters import validate
from attrs.validators import instance_of
from attrs.validators import optional as optional_v
from baybe._optional.info import FPSAMPLE_INSTALLED, POLARS_INSTALLED
from baybe.exceptions import NotAllowedError, OptionalImportError
from baybe.utils.basic import classproperty
from baybe.utils.boolean import AutoBool, to_bool
from baybe.utils.random import _RandomState
if TYPE_CHECKING:
    # Only needed for annotations; at runtime, torch is imported lazily where used
    import torch

# Type variable for random seed values: either an integer seed or ``None``
_TSeed = TypeVar("_TSeed", int, None)

# Parts of this module refer to the random seed field by its name, so keep it central
_RANDOM_SEED_ATTRIBUTE_NAME = "random_seed"

# Placeholder binding: the default factories installed on `Settings` look up
# `active_settings` while the very first (global) instance is being created further
# below, at which point they fall back to the plain field defaults. The real object
# is assigned once that instance exists.
active_settings: Settings = None  # type: ignore[assignment]
"""The active settings instance controlling execution behavior."""

# Error template for `use_<package>` settings that are forced to 'True' although the
# corresponding optional package is missing
_MISSING_PACKAGE_ERROR_MESSAGE = (
    "The setting 'use_{package_name}' cannot be set to 'True' "
    "because '{package_name}' is not installed. "
    "Either install '{package_name}' "
    "or set 'use_{package_name}' to 'False'/'Auto'."
)
def _validate_whitelist_env_vars(env_vars: dict[str, str], /) -> None:
    """Validate the values of non-settings environment variables.

    Args:
        env_vars: Mapping of ``BAYBE_*`` environment variable names to their values,
            restricted to variables that do not correspond to a settings attribute.
            The mapping is not modified.

    Raises:
        ValueError: If ``BAYBE_TEST_ENV`` holds a value other than ``CORETEST``,
            ``FULLTEST``, or ``GPUTEST``.
        RuntimeError: If any variable other than ``BAYBE_TEST_ENV`` is present.
    """
    # Work on a copy so that the caller's mapping stays intact
    remaining = dict(env_vars)

    if (value := remaining.pop("BAYBE_TEST_ENV", None)) is not None:
        if value not in {"CORETEST", "FULLTEST", "GPUTEST"}:
            raise ValueError(
                f"Allowed values for 'BAYBE_TEST_ENV' are "
                f"'CORETEST', 'FULLTEST', and 'GPUTEST'. Given: '{value}'"
            )

    # Everything left over is neither a setting nor a whitelisted variable
    if remaining:
        raise RuntimeError(
            f"Unknown 'BAYBE_*' environment variables: {set(remaining)}"
        )
class _SlottedContextDecorator:
    """Slotted counterpart of :class:`contextlib.ContextDecorator`.

    Adapted from the Python standard library implementation, with an empty
    ``__slots__`` declaration so that it does not add a ``__dict__`` to slotted
    subclasses. Subclasses provide ``__enter__``/``__exit__``; instances can then be
    used as function decorators, entering the context around every call.
    """

    __slots__ = ()

    def _recreate_cm(self):
        # The same instance serves as context manager for each decorated call
        return self

    def __call__(self, func):
        def wrapper(*args, **kwargs):
            with self._recreate_cm():
                return func(*args, **kwargs)

        # Copy name, docstring, etc. of the decorated function onto the wrapper
        return wraps(func)(wrapper)
def adjust_defaults(cls: type[Settings], fields: list[Attribute]) -> list[Attribute]:
    """Replace default values with the appropriate source, controlled via flags.

    Used as ``field_transformer`` of :class:`Settings`. Every user-facing settings
    field gets its static default replaced by a factory that resolves the default at
    instantiation time. Two kinds of field are excluded: the internal attributes
    listed in ``cls._non_settings_names``, and the random seed field.

    Args:
        cls: The class being processed.
        fields: The attributes collected for ``cls``. Note that this name shadows
            :func:`attrs.fields` within this function; it is kept unchanged for
            compatibility with keyword callers.

    Returns:
        The attributes in their original order, with the defaults of all settings
        fields replaced by instantiation-time factories.
    """
    # Fields that keep their static default. The random seed is excluded on purpose:
    # it is neither inherited from the active settings nor read from the environment.
    skipped_names = {*cls._non_settings_names, _RANDOM_SEED_ATTRIBUTE_NAME}

    # We use a factory because the environment variables should be looked up at
    # instantiation time, not at class definition time
    def make_default_factory(fld: Attribute) -> Any:
        """Create the instantiation-time default factory for a single field."""
        # TODO: https://github.com/python-attrs/attrs/issues/1479
        name = fld.alias or fld.name

        def get_default_value(self: Settings) -> Any:
            """Dynamically retrieve the default value for the field.

            Depending on the control flags, the value is retrieved either from the
            field specification itself, from the corresponding environment variable,
            or from the active settings object.
            """
            if self._restore_defaults:
                default = fld.default
            else:
                # Here, the active settings value is used as default, to enable
                # updating settings one attribute at a time (the fallback to the
                # field default happens while `active_settings` is still `None`,
                # i.e. when the active settings object is itself being created)
                default = getattr(active_settings, fld.name, fld.default)

            if not self._restore_environment:
                return default

            # If enabled, the environment values take precedence for the default
            value = os.getenv(f"BAYBE_{name.upper()}", default)
            # Annotations are normally stringified by `from __future__ import
            # annotations`; the evaluated type is accepted as well for robustness
            if fld.type in ("bool", bool):
                value = to_bool(value)
            return value

        return Factory(get_default_value, takes_self=True)

    return [
        fld if fld.name in skipped_names else fld.evolve(default=make_default_factory(fld))
        for fld in fields
    ]
def _on_set_random_seed(instance: Settings, __: Attribute, value: _TSeed) -> _TSeed:
    """``on_setattr`` hook of the ``random_seed`` field.

    If the modified object is the globally active settings instance and the new seed
    is not ``None``, the random state corresponding to the seed is activated
    immediately. For any other instance, or for a ``None`` seed, the hook does
    nothing besides passing the value through.

    Returns:
        The assigned value, unchanged.
    """
    targets_active_settings = id(instance) == Settings._active_settings_id
    if not targets_active_settings or value is None:
        return value

    _RandomState.from_seed(value, activate=True)
    return value
def _convert_cache_directory(
    value: str | Path | None, attribute: Attribute, /
) -> Path | None:
    """Attrs converter for the cache directory setting.

    Args:
        value: The raw setting value. ``None`` and the empty string both disable
            caching.
        attribute: The attribute being converted (used for error reporting only).

    Returns:
        ``None`` if caching is disabled, otherwise the value as a :class:`Path`.

    Raises:
        TypeError: If the value cannot be interpreted as a path.
    """
    if value is None or value == "":
        return None

    try:
        return Path(value)
    except TypeError as ex:
        # `Path` reports unsupported input types via `TypeError`. Other exceptions
        # (e.g. raised by a custom `__fspath__`) propagate unchanged: rebuilding them
        # via `type(ex)(message)` fails for exception types whose constructor needs
        # additional arguments.
        raise TypeError(
            f"Cannot set '{attribute.alias}' to '{value}'. "
            f"Expected 'None' or a path-like object."
        ) from ex
@define(kw_only=True, field_transformer=adjust_defaults)
class Settings(_SlottedContextDecorator):
    """BayBE settings."""

    ### Internal
    _active_settings_id: ClassVar[int]
    """The id of the active settings instance.

    Useful to identify if an action is performed on the active or a local instance."""

    _previous_settings: Settings | None = field(default=None, init=False)
    """The previously active settings (used for context management)."""

    _previous_random_state: _RandomState | None = field(default=None, init=False)
    """The previous random state (used for context management)."""

    ### Control flags
    _restore_defaults: bool = field(default=False, validator=instance_of(bool))
    """Controls if settings shall be restored to their default values."""

    _restore_environment: bool = field(default=False, validator=instance_of(bool))
    """Controls if environment variables shall be used to initialize settings."""

    ### Settings attributes
    cache_campaign_recommendations: bool = field(
        default=True, validator=instance_of(bool)
    )
    """Controls if :class:`~baybe.campaign.Campaign` objects cache their latest set of
    recommendations."""

    cache_directory: Path | None = field(
        default=Path(tempfile.gettempdir()) / ".baybe_cache",
        converter=Converter(_convert_cache_directory, takes_field=True),  # type: ignore[misc]
    )
    """The directory used for persistent caching on disk. Set to ``""`` or ``None`` to disable caching."""  # noqa: E501

    parallelize_simulation_runs: bool = field(default=True, validator=instance_of(bool))
    """Controls if simulation runs with `xyzpy <https://xyzpy.readthedocs.io/>`_ are executed in parallel."""  # noqa: E501

    preprocess_dataframes: bool = field(default=True, validator=instance_of(bool))
    """Controls if incoming user dataframes are preprocessed (i.e., dtype-converted and validated) before use."""  # noqa: E501

    random_seed: int | None = field(
        default=None,
        validator=optional_v(instance_of(int)),
        on_setattr=[validate, _on_set_random_seed],
    )
    """The used random seed."""

    _use_fpsample: AutoBool = field(
        alias="use_fpsample",
        default=AutoBool.AUTO,
        converter=AutoBool.from_unstructured,  # type: ignore[misc]
    )
    """Controls if `fpsample <https://github.com/leonardodalinky/fpsample>`_ acceleration is to be used, if available."""  # noqa: E501

    _use_polars_for_constraints: AutoBool = field(
        alias="use_polars_for_constraints",
        default=AutoBool.AUTO,
        converter=AutoBool.from_unstructured,  # type: ignore[misc]
    )
    """Controls if `polars <https://pola.rs/>`_ acceleration is to be used for constraints, if available."""  # noqa: E501

    use_single_precision_numpy: bool = field(default=False, validator=instance_of(bool))
    """Controls the floating point precision used for `numpy <https://numpy.org/>`_ arrays."""  # noqa: E501

    use_single_precision_torch: bool = field(default=False, validator=instance_of(bool))
    """Controls the floating point precision used for `torch <https://pytorch.org/>`_ tensors."""  # noqa: E501

    def __attrs_pre_init__(self) -> None:
        """Translate deprecated environment variables and validate the remaining ones.

        Runs before the fields are initialized, so translated variables are already
        visible to the default factories that read the environment.
        """
        # >>>>> Deprecation
        flds = fields(Settings)
        pairs: list[tuple[str, Attribute]] = [
            ("BAYBE_NUMPY_USE_SINGLE_PRECISION", flds.use_single_precision_numpy),
            ("BAYBE_TORCH_USE_SINGLE_PRECISION", flds.use_single_precision_torch),
            ("BAYBE_DEACTIVATE_POLARS", flds._use_polars_for_constraints),
            ("BAYBE_PARALLEL_SIMULATION_RUNS", flds.parallelize_simulation_runs),
            ("BAYBE_CACHE_DIR", flds.cache_directory),
        ]
        for env_var, fld in pairs:
            if (value := os.environ.pop(env_var, None)) is not None:
                warnings.warn(
                    f"The environment variable '{env_var}' has "
                    f"been deprecated and support will be dropped in a future version. "
                    f"Please use 'BAYBE_{(fld.alias or fld.name).upper()}' instead. "
                    f"For now, we've automatically handled the translation for you.",
                    DeprecationWarning,
                )
                # The old "deactivate" flag has inverted semantics
                if env_var.endswith("POLARS"):
                    value = "false" if to_bool(value) else "true"
                elif env_var.endswith("SIMULATION_RUNS"):
                    value = "true" if to_bool(value) else "false"
                # NOTE(review): if the new variable is also set, it is overwritten by
                # the translated deprecated value here — confirm this is intended.
                os.environ[f"BAYBE_{(fld.alias or fld.name).upper()}"] = value
        # <<<<< Deprecation

        known_env_vars = {
            f"BAYBE_{attr.alias.upper()}" for attr in self._settings_attributes
        }
        _validate_whitelist_env_vars(
            {
                k: v
                for k, v in os.environ.items()
                if k.startswith("BAYBE_") and k not in known_env_vars
            }
        )

    def __enter__(self) -> Settings:
        self.activate()
        return self

    def __exit__(self, *args) -> None:
        self.restore_previous()

    @_use_polars_for_constraints.validator
    def _validate_use_polars_for_constraints(self, _, value: AutoBool) -> None:
        """Raise if ``polars`` usage is enforced but ``polars`` is not installed."""
        if value is AutoBool.TRUE and not POLARS_INSTALLED:
            raise OptionalImportError(
                _MISSING_PACKAGE_ERROR_MESSAGE.format(package_name="polars")
            )

    @_use_fpsample.validator
    def _validate_use_fpsample(self, _, value: AutoBool) -> None:
        """Raise if ``fpsample`` usage is enforced but ``fpsample`` is not installed."""
        if value is AutoBool.TRUE and not FPSAMPLE_INSTALLED:
            raise OptionalImportError(
                _MISSING_PACKAGE_ERROR_MESSAGE.format(package_name="fpsample")
            )

    @property
    def use_polars_for_constraints(self) -> bool:
        """Indicates if ``polars`` is enabled (i.e., installed and set to be used)."""
        return self._use_polars_for_constraints.evaluate(lambda: POLARS_INSTALLED)

    @use_polars_for_constraints.setter
    def use_polars_for_constraints(self, value: AutoBool | bool, /) -> None:
        # Note: uses attrs converter
        self._use_polars_for_constraints = value  # type: ignore[assignment]

    @property
    def use_fpsample(self) -> bool:
        """Indicates if ``fpsample`` is enabled (i.e., installed and set to be used)."""  # noqa: E501
        return self._use_fpsample.evaluate(lambda: FPSAMPLE_INSTALLED)

    @use_fpsample.setter
    def use_fpsample(self, value: AutoBool | bool, /) -> None:
        # Note: uses attrs converter
        self._use_fpsample = value  # type: ignore[assignment]

    @property
    def DTypeFloatNumpy(self) -> type[np.floating]:
        """The floating point precision used for ``numpy`` arrays."""
        return np.float32 if self.use_single_precision_numpy else np.float64

    @property
    def DTypeFloatTorch(self) -> torch.dtype:
        """The floating point precision used for ``torch`` tensors."""
        import torch

        return torch.float32 if self.use_single_precision_torch else torch.float64

    @classproperty
    def _non_settings_names(cls) -> frozenset[str]:
        """The names of attributes that do not represent user-facing settings."""  # noqa: D401
        # IMPROVE: This approach is not type-safe but the set is needed already at
        # class definition time, which means we cannot use `attrs.fields` or similar.
        # Perhaps `typing.Annotated` can be used, if there's an elegant way to
        # resolve the stringified types coming from `__future__.annotations`?
        return frozenset(
            {
                "_previous_settings",
                "_previous_random_state",
                "_restore_defaults",
                "_restore_environment",
            }
        )

    @classproperty
    def _settings_attributes(cls) -> tuple[Attribute, ...]:
        """The attributes representing the available user-facing settings."""  # noqa: D401
        return tuple(
            fld
            for fld in fields(Settings)
            if fld.name not in Settings._non_settings_names
        )

    def activate(self) -> Settings:
        """Activate the settings globally.

        The currently active settings are backed up so that they can be reinstated
        via :meth:`restore_previous`. If this object defines a random seed, the
        current random state is backed up as well.

        Returns:
            The settings object itself.

        Raises:
            NotAllowedError: If called on the active settings object.
        """
        if id(self) == Settings._active_settings_id:
            raise NotAllowedError(
                f"Calling '{self.activate.__name__}' on the active settings "
                f"object is not allowed since it is already active."
            )

        # Store the previous random state only if it's actually required for settings
        # restoration later on (see `restore_previous` method). Otherwise, reset it so
        # that no stale state from an earlier activation is carried over.
        self._previous_random_state = (
            _RandomState() if self.random_seed is not None else None
        )
        self._previous_settings = deepcopy(active_settings)
        self.overwrite(active_settings)
        return self

    def restore_previous(self) -> None:
        """Restore the previous settings.

        Raises:
            NotAllowedError: If called on the active settings object.
            RuntimeError: If the settings have not been activated before.
        """
        if id(self) == Settings._active_settings_id:
            raise NotAllowedError(
                f"Calling '{self.restore_previous.__name__}' on the active settings "
                f"object is not supported."
            )
        if self._previous_settings is None:
            raise RuntimeError(
                "The settings have not yet been activated, "
                "so there are no previous settings to restore."
            )

        # When restoring, we do not want to re-sync the random state back to
        # the seed value of the previous setting, since the random state has
        # potentially progressed in the meantime ...
        self._previous_settings.overwrite(active_settings, keep_random_state=True)

        # ... Instead, we restore the random state from setting activation time, if one
        # was backed up. Checking the backup itself (rather than the current value of
        # `self.random_seed`) keeps restoration correct even if this object's seed was
        # modified after activation.
        if self._previous_random_state is not None:
            self._previous_random_state.activate()
            self._previous_random_state = None

        # Clear backup attribute
        self._previous_settings = None

    def overwrite(self, target: Settings, keep_random_state: bool = False) -> None:
        """Overwrite the settings of another :class:`Settings` object.

        Args:
            target: The object whose settings attributes are replaced by the values
                of this object.
            keep_random_state: If ``True``, the random state is reinstated after the
                transfer. Without this, assigning the seed can alter the random state
                (when ``target`` is the active settings object). The random state is
                also reinstated if an error occurs mid-transfer.
        """
        state = _RandomState() if keep_random_state else None
        try:
            for fld in self._settings_attributes:
                setattr(target, fld.name, getattr(self, fld.name))
        finally:
            if state is not None:
                state.activate()
# Sanity check: parts of this module refer to the random seed attribute through its
# hardcoded name, so verify that such a field exists (guards against future renames)
assert any(fld.name == _RANDOM_SEED_ATTRIBUTE_NAME for fld in fields(Settings))

# Garbage-collect the leftover original classes processed by `attrs.define`
gc.collect()

active_settings = Settings(restore_environment=True)
"""The current active settings."""

# Register the identity of the global instance so that methods and attribute hooks can
# tell whether they operate on it or on a local settings object
Settings._active_settings_id = id(active_settings)

# Special handling of the random seed:
# The class logic deliberately skips adopting seed values from the environment or from
# the active settings object as defaults for new settings objects, so that random
# states can progress properly. Still, a seed given via the environment should
# populate the active settings object (and *only* that object) at session start, so
# it is assigned manually here. Because the id is registered above, the `on_setattr`
# hook of `random_seed` also activates the corresponding random state.
_seed = os.environ.get(f"BAYBE_{_RANDOM_SEED_ATTRIBUTE_NAME.upper()}")
if _seed is not None:
    active_settings.random_seed = int(_seed)