"""Linear surrogates.

Currently, the documentation for this surrogate is not available. This is due to a bug
in our documentation tool, see https://github.com/sphinx-doc/sphinx/issues/11750.

Since we plan to refactor the surrogates, this part of the documentation will be
available in the future. Thus, please have a look in the source code directly.
"""

from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar

from attr import define, field
from sklearn.linear_model import ARDRegression

from baybe.searchspace import SearchSpace
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import autoscale, batchify, catch_constant_targets
from baybe.surrogates.validation import get_model_params_validator

if TYPE_CHECKING:
    from torch import Tensor
@catch_constant_targets
@autoscale
@define(slots=False)
class BayesianLinearSurrogate(Surrogate):
    """Surrogate model based on Bayesian linear regression.

    Predictions are delegated to a scikit-learn ``ARDRegression`` estimator,
    which is created from ``model_params`` each time the surrogate is fitted.
    """

    # Class variables
    joint_posterior: ClassVar[bool] = False
    # See base class.

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    # Object variables
    model_params: dict[str, Any] = field(
        factory=dict,
        converter=dict,
        validator=get_model_params_validator(ARDRegression.__init__),
    )
    """Optional keyword arguments forwarded to the ``ARDRegression`` constructor."""

    _model: ARDRegression | None = field(init=False, default=None, eq=False)
    """The wrapped regression model (``None`` until :meth:`_fit` has been called)."""

    @batchify
    def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
        # See base class.

        # Imported lazily so that loading this module does not require torch.
        import torch

        # With ``return_std=True`` the estimator hands back the predictive mean
        # together with the predictive standard deviation.
        mean_np, std_np = self._model.predict(candidates.numpy(), return_std=True)

        # The caller expects a variance, hence the standard deviation is squared.
        mean = torch.from_numpy(mean_np)
        var = torch.from_numpy(std_np) ** 2

        return mean, var

    def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
        # See base class.

        # ``searchspace`` is part of the signature but is not used by this model.
        # A fresh estimator is built on every call, so earlier fits are discarded.
        self._model = ARDRegression(**self.model_params)

        # The targets are flattened before being handed to the estimator.
        flat_targets = train_y.ravel()
        self._model.fit(train_x, flat_targets)