# Source code for baybe.constraints.base

"""Base classes for all constraints."""

from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Collection, Sequence
from typing import TYPE_CHECKING, Any, ClassVar

import numpy as np
import pandas as pd
from attr import define, field
from attr.validators import ge, instance_of, min_len

from baybe.parameters import NumericalContinuousParameter
from baybe.serialization import (
    SerialMixin,
    converter,
    get_base_structure_hook,
    unstructure_base,
)
from baybe.utils.numerical import DTypeFloatNumpy

if TYPE_CHECKING:
    import polars as pl
    from torch import Tensor


@define
class Constraint(ABC, SerialMixin):
    """Abstract base class for all constraints."""

    # class variables
    # TODO: it might turn out these are not needed at a later development stage
    eval_during_creation: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during creation."""

    eval_during_modeling: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during modeling."""

    eval_during_augmentation: ClassVar[bool] = False
    """Class variable encoding whether the constraint could be considered during data
    augmentation."""

    numerical_only: ClassVar[bool] = False
    """Class variable encoding whether the constraint is valid only for numerical
    parameters."""

    # Object variables
    parameters: list[str] = field(validator=min_len(1))
    """The list of parameters used for the constraint."""

    @parameters.validator
    def _validate_params(  # noqa: DOC101, DOC103
        self, _: Any, params: list[str]
    ) -> None:
        """Ensure that no parameter name is listed more than once.

        Raises:
            ValueError: If ``params`` contains duplicate values.
        """
        distinct_names = set(params)
        if len(distinct_names) == len(params):
            return
        raise ValueError(
            f"The given 'parameters' list must have unique values "
            f"but was: {params}."
        )

    def summary(self) -> dict:
        """Summarize the constraint as its type name and the parameters it touches."""
        return {
            "Type": self.__class__.__name__,
            "Affected_Parameters": self.parameters,
        }

    @property
    def is_continuous(self) -> bool:
        """Whether this constraint is defined over continuous parameters."""
        return isinstance(self, ContinuousConstraint)

    @property
    def is_discrete(self) -> bool:
        """Whether this constraint is defined over discrete parameters."""
        return isinstance(self, DiscreteConstraint)
@define
class DiscreteConstraint(Constraint, ABC):
    """Abstract base class for discrete constraints.

    Discrete constraints chain conditions together in order to filter unwanted
    entries out of the search space.
    """

    # class variables -- see the base class for their meaning
    eval_during_modeling: ClassVar[bool] = False
    eval_during_creation: ClassVar[bool] = True

    @abstractmethod
    def get_invalid(self, data: pd.DataFrame) -> pd.Index:
        """Determine which dataframe rows violate the constraint.

        Args:
            data: A dataframe in which every row is one parameter combination.

        Returns:
            The dataframe indices of the rows that violate the constraint.
        """

    def get_invalid_polars(self) -> pl.Expr:
        """Express the constraint as a Polars expression selecting undesired rows.

        Returns:
            The Polars expression to pass to :meth:`polars.LazyFrame.filter`.

        Raises:
            NotImplementedError: Always, unless a subclass provides a Polars
                implementation by overriding this method.
        """
        constraint_name = type(self).__name__
        raise NotImplementedError(
            f"'{constraint_name}' does not have a Polars implementation."
        )
@define
class ContinuousConstraint(Constraint, ABC):
    """Abstract base class for continuous constraints."""

    # class variables -- see the base class for their meaning
    numerical_only: ClassVar[bool] = True
    eval_during_modeling: ClassVar[bool] = True
    eval_during_creation: ClassVar[bool] = False
@define
class CardinalityConstraint(Constraint, ABC):
    """Abstract base class for cardinality constraints.

    Places a constraint on the set of nonzero (i.e. "active") values among the
    specified parameters, bounding it between the two given integers,

    ``min_cardinality <= |{p_i : p_i != 0}| <= max_cardinality``

    where ``{p_i}`` are the parameters specified for the constraint.

    Note that this can be equivalently regarded as an L0-constraint on the vector
    containing the specified parameters.
    """

    # class variable
    numerical_only: ClassVar[bool] = True  # See base class.

    # object variables
    min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
    "The minimum required cardinality."

    max_cardinality: int = field(validator=instance_of(int))
    "The maximum allowed cardinality."

    @max_cardinality.default
    def _default_max_cardinality(self):
        """Use the number of involved parameters as the upper limit by default."""
        return len(self.parameters)

    def __attrs_post_init__(self):
        """Validate the cardinality bounds.

        Raises:
            ValueError: If the lower bound exceeds the upper bound.
            ValueError: If the upper bound exceeds the number of parameters.
            ValueError: If the bounds impose no constraint, i.e. the lower bound is
                zero and the upper bound equals the number of parameters.
        """
        if self.min_cardinality > self.max_cardinality:
            raise ValueError(
                f"The lower cardinality bound cannot be larger than the upper bound. "
                f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
            )

        if self.max_cardinality > len(self.parameters):
            raise ValueError(
                f"The cardinality bound cannot exceed the number of parameters. "
                f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
            )

        # Bounds [0, n_params] admit every configuration, so the constraint would
        # be a no-op; reject it rather than silently carrying it around.
        if self.min_cardinality == 0 and self.max_cardinality == len(self.parameters):
            raise ValueError(
                f"No constraint of type '{self.__class__.__name__}' is required "
                f"when the lower cardinality bound is zero and the upper bound equals "
                f"the number of parameters. Provided values: {self.min_cardinality=}, "
                f"{self.max_cardinality=}, {len(self.parameters)=}."
            )
@define
class ContinuousLinearConstraint(ContinuousConstraint, ABC):
    """Abstract base class for continuous linear constraints.

    Continuous linear constraints use parameter lists and coefficients to define
    in-/equality constraints over a continuous parameter space.
    """

    # object variables
    coefficients: list[float] = field()
    """In-/equality coefficient for each entry in ``parameters``."""

    rhs: float = field(default=0.0)
    """Right-hand side value of the in-/equality."""

    @coefficients.validator
    def _validate_coefficients(  # noqa: DOC101, DOC103
        self, _: Any, coefficients: list[float]
    ) -> None:
        """Validate the coefficients.

        Raises:
            ValueError: If the number of coefficients does not match the number of
                parameters.
        """
        if len(self.parameters) != len(coefficients):
            raise ValueError(
                "The given 'coefficients' list must have one floating point entry for "
                "each entry in 'parameters'."
            )

    @coefficients.default
    def _default_coefficients(self):
        """Return equal weight coefficients as default."""
        return [1.0] * len(self.parameters)

    def _drop_parameters(
        self, parameter_names: Collection[str]
    ) -> ContinuousLinearConstraint:
        """Create a copy of the constraint with certain parameters removed.

        The copy has the same concrete type as ``self``, so an equality constraint
        stays an equality constraint. All other attributes, such as ``rhs``, are
        carried over unchanged. The copy is validated like any new instance. Since
        ``parameters`` requires at least one entry, removing every parameter fails
        validation.

        Args:
            parameter_names: The names of the parameters to be removed.

        Returns:
            The reduced constraint.
        """
        from attr import evolve

        kept = [
            (name, coef)
            for name, coef in zip(self.parameters, self.coefficients)
            if name not in parameter_names
        ]
        parameters = [name for name, _ in kept]
        coefficients = [coef for _, coef in kept]

        # ``evolve`` instantiates ``type(self)`` rather than this base class, so
        # the subclass identity is not lost when parameters are dropped.
        return evolve(self, parameters=parameters, coefficients=coefficients)

    def to_botorch(
        self, parameters: Sequence[NumericalContinuousParameter], idx_offset: int = 0
    ) -> tuple[Tensor, Tensor, float]:
        """Cast the constraint in a format required by botorch.

        Used in calling ``optimize_acqf_*`` functions, for details see
        https://botorch.org/api/optim.html#botorch.optim.optimize.optimize_acqf

        Constraint parameters that are not among ``parameters`` are skipped along
        with their coefficients. The returned index and coefficient tensors
        therefore always have matching lengths.

        Args:
            parameters: The parameter objects of the continuous space.
            idx_offset: Offset to the provided parameter indices.

        Returns:
            The tuple required by botorch: the tensor of (offset) parameter indices,
            the tensor of corresponding coefficients, and the right-hand side as a
            Python float.
        """
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        # Position of each parameter name within ``parameters``. ``setdefault``
        # keeps the first occurrence, matching ``list.index`` semantics.
        positions: dict[str, int] = {}
        for position, param in enumerate(parameters):
            positions.setdefault(param.name, position)

        indices: list[int] = []
        coefficients: list[float] = []
        for name, coefficient in zip(self.parameters, self.coefficients):
            if name in positions:
                indices.append(positions[name] + idx_offset)
                coefficients.append(coefficient)

        return (
            # Explicit integer dtype: for an empty list, ``torch.tensor`` would
            # otherwise fall back to the default floating point dtype.
            torch.tensor(indices, dtype=torch.long),
            torch.tensor(coefficients, dtype=DTypeFloatTorch),
            np.asarray(self.rhs, dtype=DTypeFloatNumpy).item(),
        )
class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
    """Abstract base for nonlinear constraints acting on continuous parameters."""
# Serialization: every constraint subclass is (de)serialized through the hooks
# registered for the ``Constraint`` base class.
converter.register_structure_hook(Constraint, get_base_structure_hook(Constraint))
converter.register_unstructure_hook(Constraint, unstructure_base)