"""Base classes for all acquisition functions."""from__future__importannotationsimportwarningsfromabcimportABCfrominspectimportsignaturefromtypingimportClassVarimportpandasaspdfromattrsimportdefinefrombaybe.searchspaceimportSearchSpacefrombaybe.serialization.coreimport(converter,get_base_structure_hook,unstructure_base,)frombaybe.serialization.mixinimportSerialMixinfrombaybe.surrogates.baseimportSurrogatefrombaybe.utils.basicimportclassproperty,match_attributesfrombaybe.utils.booleanimportis_abstractfrombaybe.utils.dataframeimportto_tensor
@define(frozen=True)
class AcquisitionFunction(ABC, SerialMixin):
    """Common abstract base for every acquisition function."""

    abbreviation: ClassVar[str]
    """Short alias used when resolving the type by name."""

    @classproperty
    def is_mc(cls) -> bool:
        """Whether the function is Monte-Carlo based (abbreviation starts with ``q``)."""
        return cls.abbreviation.startswith("q")

    @classproperty
    def _non_botorch_attrs(cls) -> tuple[str, ...]:
        """Attribute names that must not be forwarded to the BoTorch constructor."""
        return ()

    def to_botorch(
        self,
        surrogate: Surrogate,
        searchspace: SearchSpace,
        train_x: pd.DataFrame,
        train_y: pd.DataFrame,
    ):
        """Create the botorch-ready representation of the function.

        The BoTorch class is looked up in ``botorch.acquisition`` under the same
        name as this object's class. It is then instantiated with the matching
        attributes of this object, plus whichever context-dependent arguments
        its signature accepts.

        Args:
            surrogate: Converted via ``to_botorch()`` and passed as ``model``,
                if accepted.
            searchspace: Used to obtain integration points (``mc_points``), if
                accepted. The ``get_integration_points`` method this relies on is
                not defined on this base class.
            train_x: Converted to a tensor and passed as ``X_baseline``, if
                accepted.
            train_y: Its maximum is passed as ``best_f``, if accepted. ``.item()``
                requires the column-wise maximum to be a single value
                (presumably a single target column; confirm with callers).

        Returns:
            An instance of the corresponding BoTorch acquisition class.
        """
        import botorch.acquisition as botorch_acqf_module

        botorch_cls = getattr(botorch_acqf_module, self.__class__.__name__)

        # Attributes of this object whose names match the BoTorch constructor.
        kwargs = match_attributes(
            self, botorch_cls.__init__, ignore=self._non_botorch_attrs
        )[0]

        # Context-dependent arguments. Each is built lazily, in this fixed order,
        # and only when the BoTorch class actually accepts it. They override any
        # same-named matched attribute.
        context_factories = {
            "model": lambda: surrogate.to_botorch(),
            "best_f": lambda: train_y.max().item(),
            "X_baseline": lambda: to_tensor(train_x),
            "mc_points": lambda: to_tensor(
                self.get_integration_points(searchspace)  # type: ignore[attr-defined]
            ),
        }
        accepted = signature(botorch_cls).parameters
        for arg_name, build in context_factories.items():
            if arg_name in accepted:
                kwargs[arg_name] = build()

        return botorch_cls(**kwargs)
# Register de-/serialization hooks
def _add_deprecation_hook(hook):
    """Wrap a structure hook so legacy UCB type specifiers still deserialize.

    Used for backward compatibility only and will be removed in future versions.

    Args:
        hook: The structure hook to delegate to, called as ``hook(val, cls)``.

    Returns:
        A structure hook with the same call signature. When the target class is
        abstract and the type specifier is ``VarUCB``/``qVarUCB``, it emits a
        ``DeprecationWarning``. It then substitutes the (q)UpperConfidenceBound
        payload with ``beta=100.0`` before delegating.
    """

    def added_deprecation_hook(val: dict | str, cls: type):
        # Only deserialization through the abstract base class can carry the old
        # string type specifiers. The concrete classes appeared only after the
        # change, so they never need the shim.
        if not is_abstract(cls):
            return hook(val, cls)

        legacy_to_new = {
            "VarUCB": "UpperConfidenceBound",
            "qVarUCB": "qUpperConfidenceBound",
        }
        # The specifier is either the bare string or the dict's "type" entry.
        legacy_name = val if isinstance(val, str) else val["type"]
        if legacy_name in legacy_to_new:
            replacement = legacy_to_new[legacy_name]
            warnings.warn(
                f"The use of `{legacy_name}` is deprecated and will be disabled in a "
                f"future version. To get the same outcome, use the new "
                f"`{replacement}` class instead with a beta of 100.0.",
                DeprecationWarning,
            )
            val = {"type": replacement, "beta": 100.0}
        return hook(val, cls)

    return added_deprecation_hook


converter.register_structure_hook(
    AcquisitionFunction,
    _add_deprecation_hook(get_base_structure_hook(AcquisitionFunction)),
)
converter.register_unstructure_hook(AcquisitionFunction, unstructure_base)