Source code for baybe.constraints.continuous

"""Continuous constraints."""

from __future__ import annotations

import gc
import math
from collections.abc import Collection, Iterator, Sequence
from itertools import combinations
from math import comb
from typing import TYPE_CHECKING, Any

import cattrs
import numpy as np
from attrs import define, field
from attrs.validators import deep_iterable, gt, in_, instance_of, lt

from baybe.constraints.base import (
    CardinalityConstraint,
    ContinuousConstraint,
    ContinuousNonlinearConstraint,
)
from baybe.parameters import NumericalContinuousParameter
from baybe.utils.interval import Interval
from baybe.utils.numerical import DTypeFloatNumpy
from baybe.utils.validation import finite_float

if TYPE_CHECKING:
    from torch import Tensor

_valid_linear_constraint_operators = ["=", ">=", "<="]


@define
class ContinuousLinearConstraint(ContinuousConstraint):
    """Class for continuous linear constraints.

    Continuous linear constraints use parameter lists and coefficients to define
    in-/equality constraints over a continuous parameter space.
    """

    # object variables
    operator: str = field(validator=in_(_valid_linear_constraint_operators))
    """Defines the operator used in the equation.
    Internally this will negate rhs and coefficients for `<=`."""

    coefficients: tuple[float, ...] = field(
        converter=lambda x: cattrs.structure(x, tuple[float, ...]),
        validator=deep_iterable(member_validator=finite_float),
    )
    """In-/equality coefficient for each entry in ``parameters``."""

    rhs: float = field(default=0.0, converter=float, validator=finite_float)
    """Right-hand side value of the in-/equality."""

    is_interpoint: bool = field(
        alias="interpoint", default=False, validator=instance_of(bool)
    )
    """Flag for defining an interpoint constraint.

    While intrapoint constraints impose conditions on each individual point of a
    batch, interpoint constraints do so **across** the points of the batch. That is,
    an interpoint constraint of the form ``x <= 100`` encodes that the sum of the
    values of the parameter ``x`` across all points in the batch must be less than
    or equal to ``100``.
    """

    @coefficients.validator
    def _validate_coefficients(  # noqa: DOC101, DOC103
        self, _: Any, coefficients: Sequence[float]
    ) -> None:
        """Validate the coefficients.

        Raises:
            ValueError: If the number of coefficients does not match the number of
                parameters.
        """
        if len(self.parameters) != len(coefficients):
            raise ValueError(
                "The given 'coefficients' list must have one floating point entry for "
                "each entry in 'parameters'."
            )

    @coefficients.default
    def _default_coefficients(self) -> tuple[float, ...]:
        """Return equal weight coefficients as default."""
        return (1.0,) * len(self.parameters)

    @property
    def _multiplier(self) -> float:
        """The internal multiplier for rhs and coefficients."""
        return -1.0 if self.operator == "<=" else 1.0

    @property
    def is_eq(self) -> bool:
        """Whether this constraint models an equality (assumed inequality otherwise)."""
        return self.operator == "="

    def _drop_parameters(
        self, parameter_names: Collection[str]
    ) -> ContinuousLinearConstraint:
        """Create a copy of the constraint with certain parameters removed.

        Args:
            parameter_names: The names of the parameters to be removed.

        Returns:
            The reduced constraint.
        """
        parameters = [p for p in self.parameters if p not in parameter_names]
        coefficients = tuple(
            c
            for p, c in zip(self.parameters, self.coefficients, strict=True)
            if p not in parameter_names
        )
        return ContinuousLinearConstraint(
            parameters, self.operator, coefficients, self.rhs
        )

    def to_botorch(
        self,
        parameters: Sequence[NumericalContinuousParameter],
        idx_offset: int = 0,
        *,
        batch_size: int | None = None,
    ) -> tuple[Tensor, Tensor, float]:
        """Cast the constraint in a format required by botorch.

        Used in calling ``optimize_acqf_*`` functions, for details see
        :func:`botorch.optim.optimize.optimize_acqf`.

        Args:
            parameters: The parameter objects of the continuous space.
            idx_offset: Offset to the provided parameter indices.
            batch_size: The batch size used for the recommendation. Necessary for
                interpoint constraints as indices need to be adjusted. Ignored by
                all other constraints.

        Returns:
            The tuple required by botorch.
        """
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        assert not (batch_size is None and self.is_interpoint), (
            "No ``batch_size`` set but using interpoint constraints."
        )
        assert not (batch_size is not None and not self.is_interpoint), (
            "A ``batch_size`` was set but the constraint is not interpoint."
        )

        param_names = [p.name for p in parameters]

        # Interpoint and intrapoint require different index formats. For more
        # information, we refer to the botorch documentation:
        # https://github.com/pytorch/botorch/blob/1518b304f47f5cdbaf9c175e808c90b3a0a6b86d/botorch/optim/optimize.py#L609  # noqa: E501
        param_indices: list[int] | list[tuple[int, int]]
        coefficients: torch.Tensor
        if not self.is_interpoint:
            param_indices = [
                param_names.index(p) + idx_offset
                for p in self.parameters
                if p in param_names
            ]
            coefficients = torch.tensor(self.coefficients, dtype=DTypeFloatTorch)
        else:
            assert batch_size is not None
            param_index_dict = {
                name: param_names.index(name) for name in self.parameters
            }
            param_indices = [
                (batch, param_index_dict[name] + idx_offset)
                for name in self.parameters
                for batch in range(batch_size)
            ]
            coefficients = torch.tensor(
                self.coefficients, dtype=DTypeFloatTorch
            ).repeat_interleave(batch_size)

        return (
            torch.tensor(param_indices),
            self._multiplier * coefficients,
            np.asarray(self._multiplier * self.rhs, dtype=DTypeFloatNumpy).item(),
        )
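

# Illustrative usage sketch (not part of the original module): how a
# ContinuousLinearConstraint might be constructed and cast to the botorch format.
# The parameter names "x_1", "x_2", "x_3" and their bounds are hypothetical, and
# calling ``to_botorch`` requires torch to be installed.
def _example_linear_constraint_usage() -> None:
    """Sketch the intended usage of :class:`ContinuousLinearConstraint`."""
    # A mixture-style equality constraint: x_1 + x_2 + x_3 == 1.0
    constraint = ContinuousLinearConstraint(
        parameters=["x_1", "x_2", "x_3"],
        operator="=",
        coefficients=[1.0, 1.0, 1.0],
        rhs=1.0,
    )
    space_parameters = [
        NumericalContinuousParameter(name=name, bounds=(0.0, 1.0))
        for name in ("x_1", "x_2", "x_3")
    ]
    # Intrapoint constraints are converted without a batch size.
    indices, coefficients, rhs = constraint.to_botorch(space_parameters)
    assert constraint.is_eq and rhs == 1.0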


@define
class ContinuousCardinalityConstraint(
    CardinalityConstraint, ContinuousNonlinearConstraint
):
    """Class for continuous cardinality constraints."""

    relative_threshold: float = field(
        default=1e-3, converter=float, validator=[gt(0.0), lt(1.0)]
    )
    """A relative threshold for determining if a value is considered zero.

    The threshold is translated into an asymmetric open interval around zero via
    :meth:`get_absolute_thresholds`.

    **Note:** The interval induced by the threshold is considered **open** because
    numerical routines that optimize parameter values on the complementary set
    (i.e. the value range considered "nonzero") may push the numerical value exactly
    to the interval boundary, which should therefore also be considered "nonzero".
    """

    @property
    def n_inactive_parameter_combinations(self) -> int:
        """The number of possible inactive parameter combinations."""
        return sum(
            comb(len(self.parameters), n_inactive_parameters)
            for n_inactive_parameters in self._inactive_set_sizes()
        )

    def _inactive_set_sizes(self) -> range:
        """Get all possible sizes of inactive parameter sets."""
        return range(
            len(self.parameters) - self.max_cardinality,
            len(self.parameters) - self.min_cardinality + 1,
        )

    def inactive_parameter_combinations(self) -> Iterator[frozenset[str]]:
        """Get an iterator over all possible combinations of inactive parameters."""
        for n_inactive_parameters in self._inactive_set_sizes():
            # Convert the tuples yielded by ``combinations`` to frozensets so that
            # the items match the annotated return type.
            yield from map(
                frozenset, combinations(self.parameters, n_inactive_parameters)
            )

    def sample_inactive_parameters(self, batch_size: int = 1) -> list[set[str]]:
        """Sample sets of inactive parameters according to the cardinality constraints.

        Args:
            batch_size: The number of parameter sets to be sampled.

        Returns:
            A list of sampled inactive parameter sets, where each set holds the
            corresponding parameter names.
        """
        # The number of possible parameter configurations per set cardinality
        n_configurations_per_cardinality = [
            math.comb(len(self.parameters), n)
            for n in range(self.min_cardinality, self.max_cardinality + 1)
        ]

        # Probability of each set cardinality under the assumption that all possible
        # inactive parameter sets are equally likely
        probabilities = n_configurations_per_cardinality / np.sum(
            n_configurations_per_cardinality
        )

        # Sample the number of active/inactive parameters
        n_active_params = np.random.choice(
            np.arange(self.min_cardinality, self.max_cardinality + 1),
            batch_size,
            p=probabilities,
        )
        n_inactive_params = len(self.parameters) - n_active_params

        # Sample the inactive parameters
        inactive_params = [
            set(np.random.choice(self.parameters, n_inactive, replace=False))
            for n_inactive in n_inactive_params
        ]

        return inactive_params

    def get_absolute_thresholds(self, bounds: Interval, /) -> Interval:
        """Get the absolute thresholds for a given interval.

        Turns the relative threshold of the constraint into absolute thresholds
        for the considered interval. That is, for a given interval ``(a, b)`` with
        ``a <= 0`` and ``b >= 0``, the method returns the interval ``(r*a, r*b)``,
        where ``r`` is the relative threshold defined by the constraint.

        Args:
            bounds: The specified interval.

        Returns:
            The absolute thresholds represented as an interval.

        Raises:
            ValueError: When the specified interval does not contain zero.
        """
        if not bounds.contains(0.0):
            raise ValueError(
                f"The specified interval must contain zero. Given: {bounds.to_tuple()}."
            )

        return Interval(
            lower=self.relative_threshold * bounds.lower,
            upper=self.relative_threshold * bounds.upper,
        )
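

# Illustrative usage sketch (not part of the original module): a cardinality
# constraint allowing at most two of three hypothetical parameters "x_1"..."x_3"
# to be nonzero, and how its helper methods can be queried.
def _example_cardinality_constraint_usage() -> None:
    """Sketch the intended usage of :class:`ContinuousCardinalityConstraint`."""
    constraint = ContinuousCardinalityConstraint(
        parameters=["x_1", "x_2", "x_3"],
        min_cardinality=0,
        max_cardinality=2,
    )
    # Inactive set sizes range from 3 - 2 = 1 to 3 - 0 = 3 parameters, giving
    # C(3,1) + C(3,2) + C(3,3) = 7 possible inactive parameter combinations.
    assert constraint.n_inactive_parameter_combinations == 7
    # Sampling returns one inactive parameter set per requested batch element.
    samples = constraint.sample_inactive_parameters(batch_size=5)
    assert len(samples) == 5
    # The default relative threshold of 1e-3 maps the interval (-2, 4) to
    # absolute thresholds of roughly (-0.002, 0.004).
    thresholds = constraint.get_absolute_thresholds(Interval(lower=-2.0, upper=4.0))
    assert math.isclose(thresholds.lower, -0.002)
    assert math.isclose(thresholds.upper, 0.004)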


# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()
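

# Illustrative usage sketch (not part of the original module): the index layout
# produced for an interpoint constraint. With ``interpoint=True`` and a batch
# size, every (batch, parameter) pair receives its own index so that the
# constraint acts on the sum across the batch. The parameter name and bounds are
# hypothetical, and torch must be installed to run this.
def _example_interpoint_constraint_usage() -> None:
    """Sketch how an interpoint budget constraint maps to botorch indices."""
    # The total amount of "x_1" used across the whole batch must not exceed 100.
    constraint = ContinuousLinearConstraint(
        parameters=["x_1"],
        operator="<=",
        coefficients=[1.0],
        rhs=100.0,
        interpoint=True,
    )
    space_parameters = [
        NumericalContinuousParameter(name="x_1", bounds=(0.0, 100.0))
    ]
    indices, coefficients, rhs = constraint.to_botorch(
        space_parameters, batch_size=3
    )
    # Three (batch, parameter) index pairs; coefficients and rhs are negated
    # internally because botorch expects ">=" style inequalities.
    assert tuple(indices.shape) == (3, 2) and rhs == -100.0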