Source code for baybe.surrogates.naive
"""Naive surrogates."""
from __future__ import annotations
import gc
from typing import TYPE_CHECKING, ClassVar
from attrs import define, field
from typing_extensions import override
from baybe.surrogates.base import IndependentGaussianSurrogate
from baybe.surrogates.utils import batchify_mean_var_prediction
if TYPE_CHECKING:
from torch import Tensor
[docs]
@define
class MeanPredictionSurrogate(IndependentGaussianSurrogate):
    """A trivial surrogate model.

    It provides the average value of the training targets
    as posterior mean and a (data-independent) constant posterior variance.
    """

    supports_transfer_learning: ClassVar[bool] = False
    # See base class.

    _model: float | None = field(init=False, default=None, eq=False)
    """The estimated posterior mean value of the training targets."""

    @override
    @batchify_mean_var_prediction
    def _estimate_moments(
        self, candidates_comp_scaled: Tensor, /
    ) -> tuple[Tensor, Tensor]:
        """Return the constant posterior mean and unit variance for all candidates.

        Args:
            candidates_comp_scaled: The candidates for which to estimate the moments.
                Only the size of the first dimension is used; the values are ignored.

        Returns:
            A ``(mean, var)`` pair of 1-D tensors, each with one entry per element
            along the first dimension of ``candidates_comp_scaled``. Every mean
            entry equals the stored training-target mean and every variance entry
            is 1. Both tensors share the dtype and device of the candidates.

        Raises:
            RuntimeError: If the surrogate has not been fitted yet.
        """
        # Fail with a clear message instead of an opaque `None * Tensor` TypeError.
        if self._model is None:
            raise RuntimeError(
                f"'{type(self).__name__}' must be fitted before its moments can "
                f"be estimated."
            )

        n_candidates = len(candidates_comp_scaled)

        # TODO: use target value bounds for covariance scaling when explicitly provided
        # `new_full`/`new_ones` inherit dtype and device from the candidates. The
        # global-default `torch.ones` would otherwise yield float32 tensors on the CPU
        # regardless of the inputs.
        mean = candidates_comp_scaled.new_full((n_candidates,), self._model)
        var = candidates_comp_scaled.new_ones(n_candidates)
        return mean, var

    @override
    def _fit(self, train_x: Tensor, train_y: Tensor) -> None:
        """Store the mean of all training target values as the model.

        Args:
            train_x: The training inputs. They are ignored by this model.
            train_y: The training targets. The mean is taken over all elements.
        """
        self._model = train_y.mean().item()
# Collect leftover original slotted classes processed by `attrs.define`.
# NOTE(review): for slotted classes, `attrs.define` creates a new class object and the
# original (pre-decoration) class lingers until the garbage collector reclaims it.
# This explicit collection presumably keeps class introspection (e.g. the base
# class's `__subclasses__()`) free of those stale duplicates — confirm with callers.
# Must stay after the class definition above to have any effect on it.
gc.collect()