"""Random forest surrogates.Currently, the documentation for this surrogate is not available. This is due to a bugin our documentation tool, see https://github.com/sphinx-doc/sphinx/issues/11750.Since we plan to refactor the surrogates, this part of the documentation will beavailable in the future. Thus, please have a look in the source code directly."""from__future__importannotationsfromtypingimportTYPE_CHECKING,Any,ClassVarimportnumpyasnpfromattrimportdefine,fieldfromsklearn.ensembleimportRandomForestRegressorfrombaybe.searchspaceimportSearchSpacefrombaybe.surrogates.baseimportSurrogatefrombaybe.surrogates.utilsimportautoscale,batchify,catch_constant_targetsfrombaybe.surrogates.validationimportget_model_params_validatorifTYPE_CHECKING:fromtorchimportTensor
@catch_constant_targets
@autoscale
@define(slots=False)
class RandomForestSurrogate(Surrogate):
    """A random forest surrogate model.

    The posterior at a set of candidates is approximated by the mean and the variance
    of the individual tree predictions of a scikit-learn ``RandomForestRegressor``.
    """

    # Class variables
    joint_posterior: ClassVar[bool] = False
    # See base class.

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    # Object variables
    model_params: dict[str, Any] = field(
        factory=dict,
        converter=dict,
        validator=get_model_params_validator(RandomForestRegressor.__init__),
    )
    """Optional keyword arguments passed to the ``RandomForestRegressor`` constructor."""

    _model: RandomForestRegressor | None = field(init=False, default=None, eq=False)
    """The actual model."""

    @batchify
    def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
        # See base class.

        import torch

        # Evaluate every fitted tree on the candidates. Iterating over the fitted
        # ``estimators_`` list (instead of indexing up to the ``n_estimators``
        # constructor parameter) ties the loop to the trees that actually exist.
        # NOTE: explicit conversion to ndarray is needed due to a pytorch issue:
        # https://github.com/pytorch/pytorch/pull/51731
        # https://github.com/pytorch/pytorch/issues/13918
        predictions = torch.from_numpy(
            np.asarray([tree.predict(candidates) for tree in self._model.estimators_])
        )

        # Compute posterior mean and variance across the trees (dim 0 = tree index)
        mean = predictions.mean(dim=0)
        if predictions.shape[0] > 1:
            var = predictions.var(dim=0)
        else:
            # The default (unbiased) variance of a single sample divides by zero
            # degrees of freedom and yields NaN; report zero spread instead.
            var = torch.zeros_like(mean)

        return mean, var

    def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
        # See base class.
        self._model = RandomForestRegressor(**self.model_params)
        # Flatten the targets (presumably shaped (n, 1)) into the 1-D form that
        # scikit-learn expects for single-output regression.
        self._model.fit(train_x, train_y.ravel())