"""Base classes for all constraints."""from__future__importannotationsimportgcfromabcimportABC,abstractmethodfromtypingimportTYPE_CHECKING,Any,ClassVarimportpandasaspdfromattrsimportdefine,fieldfromattrs.validatorsimportge,instance_of,min_lenfrombaybe.constraints.deprecationimport(ContinuousLinearEqualityConstraint,ContinuousLinearInequalityConstraint,)frombaybe.serializationimport(SerialMixin,)frombaybe.serialization.coreimport(converter,)ifTYPE_CHECKING:importpolarsaspl
@define
class Constraint(ABC, SerialMixin):
    """Abstract base class for all constraints."""

    # class variables
    # TODO: it might turn out these are not needed at a later development stage
    eval_during_creation: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during creation."""

    eval_during_modeling: ClassVar[bool]
    """Class variable encoding whether the condition is evaluated during modeling."""

    eval_during_augmentation: ClassVar[bool] = False
    """Class variable encoding whether the constraint could be considered during data augmentation."""  # noqa: E501

    numerical_only: ClassVar[bool] = False
    """Class variable encoding whether the constraint is valid only for numerical parameters."""  # noqa: E501

    # Object variables
    parameters: list[str] = field(validator=min_len(1))
    """The list of parameters used for the constraint."""

    @parameters.validator
    def _validate_params(  # noqa: DOC101, DOC103
        self, _: Any, params: list[str]
    ) -> None:
        """Ensure that no parameter name is listed more than once.

        Raises:
            ValueError: If ``params`` contains duplicate values.
        """
        # Building a set collapses repeated names, so a shrunken size reveals a
        # duplicate entry.
        if len(set(params)) < len(params):
            raise ValueError(
                f"The given 'parameters' list must have unique values "
                f"but was: {params}."
            )

    def summary(self) -> dict:
        """Summarize the constraint as its class name and the affected parameters."""
        return {
            "Type": type(self).__name__,
            "Affected_Parameters": self.parameters,
        }

    @property
    def is_continuous(self) -> bool:
        """Whether this constraint is an instance of ``ContinuousConstraint``."""
        return isinstance(self, ContinuousConstraint)

    @property
    def is_discrete(self) -> bool:
        """Whether this constraint is an instance of ``DiscreteConstraint``."""
        return isinstance(self, DiscreteConstraint)
@define
class DiscreteConstraint(Constraint, ABC):
    """Abstract base class for discrete constraints.

    A discrete constraint chains conditions together in order to filter unwanted
    entries out of the search space.
    """

    # class variables
    eval_during_creation: ClassVar[bool] = True
    # See base class.

    eval_during_modeling: ClassVar[bool] = False
    # See base class.

    def get_valid(self, df: pd.DataFrame, /) -> pd.Index:
        """Return the index labels of the rows that satisfy the constraint.

        Args:
            df: A dataframe where each row represents a parameter configuration.

        Returns:
            The index of ``df`` with all constraint-violating rows removed.
        """
        # The valid rows are the complement of what ``get_invalid`` reports.
        return df.index.drop(self.get_invalid(df))

    @abstractmethod
    def get_invalid(self, data: pd.DataFrame) -> pd.Index:
        """Return the index labels of the rows that violate the constraint.

        Args:
            data: A dataframe where each row represents a parameter configuration.

        Returns:
            The dataframe indices of rows that violate the constraint.
        """

    # TODO: Should switch backends (pandas/polars/...) behind the scenes
    def get_invalid_polars(self) -> pl.Expr:
        """Express the constraint as a Polars expression selecting undesired rows.

        Returns:
            The Polars expressions to pass to :meth:`polars.LazyFrame.filter`.

        Raises:
            NotImplementedError: If the constraint class does not have a Polars
                implementation.
        """
        class_name = type(self).__name__
        raise NotImplementedError(
            f"'{class_name}' does not have a Polars implementation."
        )
@define
class ContinuousConstraint(Constraint, ABC):
    """Abstract base class for continuous constraints."""

    # Values for the flags declared on ``Constraint`` (see there for their meaning):
    # continuous constraints apply only to numerical parameters and are handled
    # during modeling rather than during search space creation.
    numerical_only: ClassVar[bool] = True
    eval_during_modeling: ClassVar[bool] = True
    eval_during_creation: ClassVar[bool] = False
@define
class CardinalityConstraint(Constraint, ABC):
    r"""Abstract base class for cardinality constraints.

    Places a constraint on the set of nonzero (i.e. "active") values among the
    specified parameters, bounding it between the two given integers,
    i.e.

    .. math::
        \text{min_cardinality} \leq |\{p_i : p_i \neq 0\}| \leq \text{max_cardinality}

    where :math:`\{p_i\}` are the parameters specified for the constraint.

    Note that this can be equivalently regarded as an L0-constraint on the vector
    containing the specified parameters.
    """

    # class variable
    numerical_only: ClassVar[bool] = True
    # See base class.

    # object variables
    min_cardinality: int = field(default=0, validator=[instance_of(int), ge(0)])
    "The minimum required cardinality."

    max_cardinality: int = field(validator=instance_of(int))
    "The maximum allowed cardinality."

    @max_cardinality.default
    def _default_max_cardinality(self):
        """Use the number of involved parameters as the upper limit by default."""
        return len(self.parameters)

    def __attrs_post_init__(self):
        """Validate the cardinality bounds.

        These checks involve more than one field, which is why they live here
        instead of in the per-field validators.

        Raises:
            ValueError: If the provided cardinality bounds are invalid.
            ValueError: If the provided cardinality bounds impose no constraint.
        """
        if self.min_cardinality > self.max_cardinality:
            raise ValueError(
                f"The lower cardinality bound cannot be larger than the upper bound. "
                f"Provided values: {self.max_cardinality=}, {self.min_cardinality=}."
            )

        if self.max_cardinality > len(self.parameters):
            raise ValueError(
                f"The cardinality bound cannot exceed the number of parameters. "
                f"Provided values: {self.max_cardinality=}, {len(self.parameters)=}."
            )

        # Allowing anywhere from zero up to all parameters to be active admits every
        # configuration, so such a constraint would have no effect.
        if self.min_cardinality == 0 and self.max_cardinality == len(self.parameters):
            raise ValueError(
                f"No constraint of type '{self.__class__.__name__}' is required "
                f"when the lower cardinality bound is zero and the upper bound equals "
                f"the number of parameters. Provided values: {self.min_cardinality=}, "
                f"{self.max_cardinality=}, {len(self.parameters)=}."
            )
# Decorated with ``@define`` like the sibling abstract bases. Without it, this class
# would declare no ``__slots__``, giving instances of its subclasses a ``__dict__``
# and breaking the slotted attrs layout used throughout this module.
@define
class ContinuousNonlinearConstraint(ContinuousConstraint, ABC):
    """Abstract base class for continuous nonlinear constraints."""
# >>>>> Deprecation handling
_hook = converter.get_structure_hook(Constraint)

# Legacy constraint class names still accepted in serialized configs, mapped to the
# (deprecated) callables imported from ``baybe.constraints.deprecation``.
_LEGACY_CONSTRAINT_FACTORIES = {
    "ContinuousLinearEqualityConstraint": ContinuousLinearEqualityConstraint,
    "ContinuousLinearInequalityConstraint": ContinuousLinearInequalityConstraint,
}


def _deprecate_legacy_classes(dct: dict[str, Any], cl: Any) -> Constraint:
    """Enable constraint configs using legacy class names.

    Args:
        dct: The unstructured constraint configuration. It is not modified.
        cl: The target type requested from the converter.

    Returns:
        The structured constraint object.
    """
    type_name = dct.get("type")
    factory = (
        _LEGACY_CONSTRAINT_FACTORIES.get(type_name)
        if isinstance(type_name, str)
        else None
    )
    if factory is None:
        # Not a legacy name (or no type tag at all): defer to the regular hook,
        # which produces the appropriate result or error.
        return _hook(dct, cl)

    # Strip the type tag on a copy so the caller's dictionary stays intact.
    kwargs = {key: value for key, value in dct.items() if key != "type"}
    return factory(**kwargs)


converter.register_structure_hook_func(
    lambda c: c is Constraint, _deprecate_legacy_classes
)
# <<<<< Deprecation handling

# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()