# Source code for baybe.surrogates.random_forest
"""Random forest surrogates."""
from __future__ import annotations
from collections.abc import Collection
from typing import TYPE_CHECKING, ClassVar, Literal, Protocol, TypedDict
import numpy as np
import numpy.typing as npt
from attrs import define, field
from numpy.random import RandomState
from typing_extensions import override
from baybe.parameters.base import Parameter
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify_ensemble_predictor, catch_constant_targets
from baybe.surrogates.validation import make_dict_validator
from baybe.utils.conversion import to_string
if TYPE_CHECKING:
from botorch.models.ensemble import EnsemblePosterior
from botorch.models.transforms.input import InputTransform
from botorch.models.transforms.outcome import OutcomeTransform
from torch import Tensor
class _RandomForestRegressorParams(TypedDict, total=False):
"""Optional RandomForestRegressor parameters.
See :class:`~sklearn.ensemble.RandomForestRegressor`.
"""
n_estimators: int
criterion: Literal["squared_error", "absolute_error", "friedman_mse", "poisson"]
max_depth: int
min_samples_split: int | float
min_samples_leaf: int | float
min_weight_fraction_leaf: float
max_features: Literal["sqrt", "log2"] | int | float | None
max_leaf_nodes: int | None
min_impurity_decrease: float
bootstrap: bool
oob_score: bool
n_jobs: int | None
random_state: int | RandomState | None
verbose: int
warm_start: bool
ccp_alpha: float
max_samples: int | float | None
monotonic_cst: npt.ArrayLike | int | None
class _Predictor(Protocol):
    """Structural type for anything that offers a ``predict`` method."""

    def predict(self, x: np.ndarray, /) -> np.ndarray:
        """Map an input array to an array of predictions."""
        ...
# [docs]
@catch_constant_targets
@define
class RandomForestSurrogate(Surrogate):
    """Surrogate model backed by a random forest regressor."""

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    model_params: _RandomForestRegressorParams = field(
        factory=dict,
        converter=dict,
        validator=make_dict_validator(_RandomForestRegressorParams),
    )
    """Optional keyword arguments forwarded to the model constructor when fitting.

    For allowed keys and values, see :class:`~sklearn.ensemble.RandomForestRegressor`.
    """

    # TODO: type should be `RandomForestRegressor | None` but is currently omitted due
    #   to: https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, default=None, eq=False)
    """The fitted model (``None`` until :meth:`_fit` has been called)."""

    @override
    @staticmethod
    def _make_parameter_scaler_factory(
        parameter: Parameter,
    ) -> type[InputTransform] | None:
        """Return ``None``: no input scaler is used for any parameter."""
        # Tree-like models do not require any input scaling
        return None

    @override
    @staticmethod
    def _make_target_scaler_factory() -> type[OutcomeTransform] | None:
        """Return ``None``: no output scaler is used for the targets."""
        # Tree-like models do not require any output scaling
        return None

    @override
    def _posterior(self, candidates_comp_scaled: Tensor, /) -> EnsemblePosterior:
        """Build an ensemble posterior from the individual estimators' predictions."""
        from botorch.models.ensemble import EnsemblePosterior

        @batchify_ensemble_predictor
        def predict(candidates_comp_scaled: Tensor) -> Tensor:
            """Make the end-to-end ensemble prediction."""
            import torch

            # FIXME[typing]: It seems there is currently no better way to inform the
            #   type checker that the attribute is available at the time of the
            #   function call
            assert self._model is not None

            # Query every estimator of the fitted forest on the numpy view of
            # the candidates, then hand the result back to torch.
            ensemble_output = self._predict_ensemble(
                self._model.estimators_, candidates_comp_scaled.numpy()
            )
            return torch.from_numpy(ensemble_output)

        # The posterior receives the predictions with an extra trailing dimension.
        samples = predict(candidates_comp_scaled)
        return EnsemblePosterior(samples.unsqueeze(-1))

    @staticmethod
    def _predict_ensemble(
        predictors: Collection[_Predictor], candidates: np.ndarray
    ) -> np.ndarray:
        """Evaluate an ensemble of predictors on a given candidate set.

        Args:
            predictors: The ensemble members to be queried.
            candidates: The candidates passed to each member's ``predict``.

        Returns:
            An array with one row per predictor and one column per candidate.
        """
        # Preallocate the (predictor x candidate) result matrix
        results = np.zeros((len(predictors), len(candidates)))

        # Fill the matrix row by row, one ensemble member at a time
        for row, member in zip(results, predictors):
            row[:] = member.predict(candidates)

        return results

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Create a fresh random forest from ``model_params`` and fit it."""
        from sklearn.ensemble import RandomForestRegressor

        self._model = RandomForestRegressor(**self.model_params)

        # The targets are flattened to one dimension before fitting.
        features = train_x.numpy()
        targets = train_y.numpy().ravel()
        self._model.fit(features, targets)

    @override
    def __str__(self) -> str:
        """Return the base class string extended by the model parameters."""
        params_entry = to_string("Model Params", self.model_params, single_line=True)
        return to_string(super().__str__(), params_entry)