"""NGBoost surrogates.

Currently, the documentation for this surrogate is not available. This is due to a
bug in our documentation tool, see https://github.com/sphinx-doc/sphinx/issues/11750.

Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look at the source code directly.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar

from attr import define, field
from ngboost import NGBRegressor

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import autoscale, batchify, catch_constant_targets
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
    from torch import Tensor


@catch_constant_targets
@autoscale
@define(slots=False)
class NGBoostSurrogate(Surrogate):
    """A surrogate model based on natural gradient boosting (NGBoost).

    The class wraps an ``NGBRegressor``: fitting builds and trains a fresh
    regressor from ``model_params``, and the posterior is read off the
    predictive distribution object the regressor returns.
    """

    # Class variables
    joint_posterior: ClassVar[bool] = False
    # See base class.

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    _default_model_params: ClassVar[dict] = {"n_estimators": 25, "verbose": False}
    """Class variable encoding the default model parameters."""

    # Object variables
    model_params: dict[str, Any] = field(
        factory=dict,
        converter=dict,
        validator=get_model_params_validator(NGBRegressor.__init__),
    )
    """Optional model parameter that will be passed to the surrogate constructor."""

    _model: NGBRegressor | None = field(init=False, default=None, eq=False)
    """The actual model."""

    def __attrs_post_init__(self):
        """Complete the user-supplied model parameters with the class defaults.

        Defaults are laid down first and then overwritten by ``model_params``,
        so any key provided by the user takes precedence.
        """
        merged_params = dict(self._default_model_params)
        merged_params.update(self.model_params)
        self.model_params = merged_params

    @batchify
    def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
        # See base class.
        import torch

        # Predictive distribution object for the candidate points.
        predictive = self._model.pred_dist(candidates)

        # The mean is obtained through a method call whereas the variance is
        # read as a plain attribute; both are converted from NumPy to torch.
        # NOTE(review): this asymmetry follows how the object returned by
        # ``pred_dist`` is used here -- confirm against the NGBoost API before
        # "harmonizing" the two accesses.
        posterior_mean = torch.from_numpy(predictive.mean())
        posterior_var = torch.from_numpy(predictive.var)
        return posterior_mean, posterior_var

    def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
        # See base class.
        regressor = NGBRegressor(**self.model_params)

        # Targets are flattened before fitting; the value returned by ``fit``
        # is what gets stored as the trained model.
        self._model = regressor.fit(train_x, train_y.ravel())