Source code for baybe.parameters.numerical

"""Numerical parameters."""

import gc
from functools import cached_property
from typing import Any, ClassVar

import cattrs
import numpy as np
import pandas as pd
from attrs import define, field
from attrs.validators import min_len
from typing_extensions import override

from baybe.exceptions import NumericalUnderflowError
from baybe.parameters.base import ContinuousParameter, DiscreteParameter
from baybe.parameters.validation import validate_is_finite, validate_unique_values
from baybe.utils.interval import InfiniteIntervalError, Interval
from baybe.utils.numerical import DTypeFloatNumpy


[docs] @define(frozen=True, slots=False) class NumericalDiscreteParameter(DiscreteParameter): """Class for discrete numerical parameters (a.k.a. setpoints).""" # class variables is_numerical: ClassVar[bool] = True # See base class. # object variables # NOTE: The parameter values are assumed to be sorted by the tolerance validator. _values: tuple[float, ...] = field( alias="values", # FIXME[typing]: https://github.com/python-attrs/cattrs/issues/111 converter=lambda x: sorted(cattrs.structure(x, tuple[float, ...])), # type: ignore # FIXME[typing]: https://github.com/python-attrs/attrs/issues/1197 validator=[ min_len(2), validate_unique_values, # type: ignore validate_is_finite, ], ) """The values the parameter can take.""" tolerance: float = field(default=0.0) """The absolute tolerance used for deciding whether a value is in range. A tolerance larger than half the minimum distance between parameter values is not allowed because that could cause ambiguity when inputting data points later.""" @tolerance.validator def _validate_tolerance( # noqa: DOC101, DOC103 self, _: Any, tolerance: float ) -> None: """Validate that the given tolerance is safe. The tolerance is the allowed experimental uncertainty when reading in measured values. A tolerance larger than half the minimum distance between parameter values is not allowed because that could cause ambiguity when inputting data points later. Raises: ValueError: If the tolerance is not safe. """ # For zero tolerance, the only left requirement is that all parameter values # are distinct, which is already ensured by the corresponding validator. if tolerance == 0.0: return min_dist = np.diff(self._values).min() if min_dist == (eps := np.nextafter(0, 1)): raise NumericalUnderflowError( f"The distance between any two parameter values must be at least " f"twice the size of the used floating point resolution of {eps}." ) if tolerance >= (max_tol := min_dist / 2.0): raise ValueError( f"Parameter '{self.name}' is initialized with tolerance {tolerance} " f"but due to the given parameter values {self.values}, the specified " f"tolerance must be smaller than {max_tol} to avoid ambiguity." ) @override @property def values(self) -> tuple: return tuple(DTypeFloatNumpy(itm) for itm in self._values) @override @cached_property def comp_df(self) -> pd.DataFrame: comp_df = pd.DataFrame( {self.name: self.values}, index=self.values, dtype=DTypeFloatNumpy ) return comp_df
[docs] @override def is_in_range(self, item: float) -> bool: return any( Interval(val - self.tolerance, val + self.tolerance).contains(item) for val in self.values )
[docs] @define(frozen=True, slots=False) class NumericalContinuousParameter(ContinuousParameter): """Class for continuous numerical parameters.""" is_numerical: ClassVar[bool] = True # See base class. # TODO[typing]: https://github.com/python-attrs/attrs/issues/1435 bounds: Interval = field(default=None, converter=Interval.create) # type: ignore[misc] """The bounds of the parameter.""" @bounds.validator def _validate_bounds(self, _: Any, value: Interval) -> None: # noqa: DOC101, DOC103 """Validate bounds. Raises: InfiniteIntervalError: If the provided interval is not finite. """ if not value.is_bounded: raise InfiniteIntervalError( f"You are trying to initialize a parameter with an infinite range " f"of {value.to_tuple()}. Infinite intervals for parameters are " f"currently not supported." ) if value.is_degenerate: raise ValueError( "The interval specified by the parameter bounds cannot be degenerate." )
[docs] @override def is_in_range(self, item: float) -> bool: return self.bounds.contains(item)
@override @property def comp_rep_columns(self) -> tuple[str]: return (self.name,)
[docs] @override def summary(self) -> dict: param_dict = dict( Name=self.name, Type=self.__class__.__name__, Lower_Bound=self.bounds.lower, Upper_Bound=self.bounds.upper, ) return param_dict
@define(frozen=True, slots=False) class _FixedNumericalContinuousParameter(ContinuousParameter): """Parameter class for fixed numerical parameters.""" is_numeric: ClassVar[bool] = True # See base class. value: float = field(converter=float) """The fixed value of the parameter.""" @property def bounds(self) -> Interval: """The value of the parameter as a degenerate interval.""" return Interval(self.value, self.value) @override def is_in_range(self, item: float) -> bool: return Interval(self.value, self.value).contains(item) @override @property def comp_rep_columns(self) -> tuple[str]: return (self.name,) @override def summary(self) -> dict: return dict( Name=self.name, Type=self.__class__.__name__, Value=self.value, ) # Collect leftover original slotted classes processed by `attrs.define` gc.collect()