Source code for baybe.recommenders.pure.nonpredictive.sampling
"""Recommenders based on sampling."""

import os
from enum import Enum
from typing import ClassVar

import numpy as np
import pandas as pd
from attrs import define, field, fields
from attrs.validators import instance_of
from typing_extensions import override

from baybe._optional.info import FPSAMPLE_INSTALLED
from baybe.recommenders.pure.nonpredictive.base import NonPredictiveRecommender
from baybe.searchspace import SearchSpace, SearchSpaceType, SubspaceDiscrete
from baybe.utils.boolean import strtobool
from baybe.utils.conversion import to_string
from baybe.utils.sampling_algorithms import farthest_point_sampling

# Whether the optional `fpsample` backend is used for farthest point sampling.
# The environment variable BAYBE_USE_FPSAMPLE takes precedence; when it is not
# set, the flag falls back to whether the package is installed.
FPSAMPLE_USED = strtobool(
    os.getenv("BAYBE_USE_FPSAMPLE", default=str(FPSAMPLE_INSTALLED))
)
class RandomRecommender(NonPredictiveRecommender):
    """Recommender that picks its experiments purely at random."""

    # Class variables
    compatibility: ClassVar[SearchSpaceType] = SearchSpaceType.HYBRID
    # See base class.

    @override
    def _recommend_hybrid(
        self,
        searchspace: SearchSpace,
        candidates_exp: pd.DataFrame,
        batch_size: int,
    ) -> pd.DataFrame:
        """Draw ``batch_size`` random points from the given search space.

        Args:
            searchspace: The search space to sample from. Its type decides which
                of the three sampling paths below is taken.
            candidates_exp: The discrete candidates to draw from when the search
                space is purely discrete.
            batch_size: The number of points to draw.

        Returns:
            A dataframe holding the randomly drawn points.
        """
        # Discrete-only space: sample straight from the supplied candidates.
        if searchspace.type == SearchSpaceType.DISCRETE:
            return candidates_exp.sample(n=batch_size)

        continuous_draws = searchspace.continuous.sample_uniform(
            batch_size=batch_size
        )

        # Continuous-only space: the uniform draws are already the result.
        if searchspace.type == SearchSpaceType.CONTINUOUS:
            return continuous_draws

        # Hybrid space: pair the continuous draws with random discrete rows.
        # NOTE(review): this path queries the discrete subspace for candidates
        #   and does not read `candidates_exp` -- confirm this is intended.
        discrete_pool, _ = searchspace.discrete.get_candidates()

        # TODO decide mechanism if number of possible discrete candidates is smaller
        #   than batch size
        # Until then, fall back to sampling with replacement if the pool is too small.
        needs_replacement = len(discrete_pool) < batch_size
        discrete_draws = discrete_pool.sample(
            n=batch_size,
            replace=needs_replacement,
        )

        # Give both parts the same row labels so they are joined row by row.
        continuous_draws.index = discrete_draws.index
        return pd.concat([discrete_draws, continuous_draws], axis=1)

    @override
    def __str__(self) -> str:
        """Return a summary consisting of the class name and its compatibility."""
        summary_parts = [
            to_string("Compatibility", self.compatibility, single_line=True)
        ]
        return to_string(self.__class__.__name__, *summary_parts)
class FPSInitialization(Enum):
    """Available strategies for initializing farthest point sampling."""

    FARTHEST = "farthest"
    """Start with the two points that have the largest distance between them."""

    RANDOM = "random"
    """Start with a single point drawn uniformly at random."""
@define
class FPSRecommender(NonPredictiveRecommender):
    """An initial recommender that selects candidates via Farthest Point Sampling.

    If the optional package `fpsample` is installed, its implementation will be used,
    otherwise a custom fallback implementation is used. The use of a specific
    implementation can be enforced by setting the environment variable
    'BAYBE_USE_FPSAMPLE'.
    """

    # Class variables
    compatibility: ClassVar[SearchSpaceType] = SearchSpaceType.DISCRETE
    # See base class.

    initialization: FPSInitialization = field(
        default=FPSInitialization.FARTHEST, converter=FPSInitialization
    )
    """See :func:`~baybe.utils.sampling_algorithms.farthest_point_sampling`.

    If the optional package 'fpsample' is used, only
    :attr:`~baybe.recommenders.pure.nonpredictive.sampling.FPSInitialization.FARTHEST`
    is supported.
    """

    random_tie_break: bool = field(validator=instance_of(bool), kw_only=True)
    """See :func:`~baybe.utils.sampling_algorithms.farthest_point_sampling`.

    If the optional package 'fpsample' is used, only ``False`` is supported.
    """

    @initialization.validator
    def _validate_initialization(self, _, value):
        """Reject initialization methods that the active backend cannot handle.

        Args:
            _: The attribute being validated (unused).
            value: The (already converted) initialization method to validate.

        Raises:
            ValueError: If `fpsample` is in use and ``value`` is anything other
                than :attr:`FPSInitialization.FARTHEST`.
        """
        if FPSAMPLE_USED and value is not FPSInitialization.FARTHEST:
            # Report `value` rather than `self.initialization`: when the validator
            # runs on attribute assignment, the instance still holds the previous
            # value, so the message would otherwise name the wrong method.
            raise ValueError(
                f"{self.__class__.__name__} is using the optional 'fpsample' "
                f"package, which does not support '{value}'. "
                f"Please choose a supported initialization method or bypass `fpsample` "
                f"usage by setting the environment variable "
                f"BAYBE_USE_FPSAMPLE."
            )

    @random_tie_break.default
    def _default_random_tie_break(self) -> bool:
        """Enable random tie-breaking unless the farthest initialization is used."""
        return self.initialization is not FPSInitialization.FARTHEST

    @random_tie_break.validator
    def _validate_random_tie_break(self, _, value):
        """Reject random tie-breaking when the `fpsample` backend is in use.

        Args:
            _: The attribute being validated (unused).
            value: The tie-breaking flag to validate.

        Raises:
            ValueError: If `fpsample` is in use and ``value`` is truthy.
        """
        if FPSAMPLE_USED and value:
            raise ValueError(
                f"'{self.__class__.__name__}' is using the optional 'fpsample' "
                f"package, which does not support random tie-breaking. "
                f"To disable the mechanism, set "
                f"'{fields(self.__class__).random_tie_break.name}=False' or bypass "
                f"`fpsample` usage by setting the environment variable "
                f"BAYBE_USE_FPSAMPLE."
            )

    @override
    def _recommend_discrete(
        self,
        subspace_discrete: SubspaceDiscrete,
        candidates_exp: pd.DataFrame,
        batch_size: int,
    ) -> pd.Index:
        """Select ``batch_size`` candidates via farthest point sampling.

        Args:
            subspace_discrete: The discrete subspace. Its computational
                representation is used to fit the scaler, and it transforms the
                candidates into that representation.
            candidates_exp: The candidates in experimental representation.
            batch_size: The number of candidates to select.

        Returns:
            The index labels of the selected candidates.
        """
        # Fit scaler on entire search space
        from sklearn.preprocessing import StandardScaler

        # TODO [Scaling]: scaling should be handled by search space object
        scaler = StandardScaler()
        scaler.fit(subspace_discrete.comp_rep)

        # Scale and sample
        candidates_comp = subspace_discrete.transform(candidates_exp)
        candidates_scaled = np.ascontiguousarray(scaler.transform(candidates_comp))

        if FPSAMPLE_USED:
            # Imported lazily so the optional dependency is only touched when used
            from baybe._optional.fpsample import fps_sampling

            ilocs = fps_sampling(
                candidates_scaled,
                n_samples=batch_size,
            )
        else:
            # Custom implementation as fallback
            ilocs = farthest_point_sampling(
                candidates_scaled,
                batch_size,
                initialization=self.initialization.value,
                random_tie_break=self.random_tie_break,
            )

        # Map the positional selection back to the candidates' index labels
        return candidates_comp.index[ilocs]

    @override
    def __str__(self) -> str:
        """Return a summary consisting of the class name and its compatibility."""
        # Named `summary_parts` to avoid shadowing the imported `attrs.fields`
        summary_parts = [
            to_string("Compatibility", self.compatibility, single_line=True)
        ]
        return to_string(self.__class__.__name__, *summary_parts)