
"""Functionality for building custom surrogates.

Note that ONNX surrogate models cannot be retrained. However, having the surrogates
raise a ``NotImplementedError`` would currently break the code since
:class:`baybe.recommenders.pure.bayesian.base.BayesianRecommender` assumes that
surrogates can be trained and attempts to do so for each new DOE iteration.

Resolving this limitation is planned for the future.
"""

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, ClassVar

from attrs import define, field, validators

from baybe.parameters import (
    CategoricalEncoding,
    CategoricalParameter,
    CustomDiscreteParameter,
    NumericalContinuousParameter,
    NumericalDiscreteParameter,
    TaskParameter,
)
from baybe.searchspace import SearchSpace
from baybe.serialization.core import block_serialization_hook, converter
from baybe.surrogates.base import Surrogate
from baybe.surrogates.utils import batchify, catch_constant_targets
from baybe.surrogates.validation import validate_custom_architecture_cls
from baybe.utils.numerical import DTypeFloatONNX

if TYPE_CHECKING:
    import onnxruntime as ort
    from torch import Tensor


def register_custom_architecture(
    joint_posterior_attr: bool = False,
    constant_target_catching: bool = True,
    batchify_posterior: bool = True,
) -> Callable:
    """Wrap a given custom model architecture class into a ``Surrogate``.

    Args:
        joint_posterior_attr: Boolean indicating if the model returns a posterior
            distribution jointly across candidates or on individual points.
        constant_target_catching: Boolean indicating if the model cannot handle
            constant target values and needs the ``@catch_constant_targets``
            decorator.
        batchify_posterior: Boolean indicating if the model is incompatible with
            t- and q-batching and needs the ``@batchify`` decorator for its
            posterior.

    Returns:
        A function that wraps around a model class based on the specifications.
    """

    def construct_custom_architecture(model_cls):
        """Construct a surrogate class wrapped around the custom class."""
        validate_custom_architecture_cls(model_cls)

        class CustomArchitectureSurrogate(Surrogate):
            """Wraps around a custom architecture class."""

            joint_posterior: ClassVar[bool] = joint_posterior_attr
            supports_transfer_learning: ClassVar[bool] = False

            def __init__(self, *args, **kwargs):
                self._model = model_cls(*args, **kwargs)

            def _fit(
                self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor
            ) -> None:
                return self._model._fit(searchspace, train_x, train_y)

            def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
                return self._model._posterior(candidates)

            def __getattribute__(self, attr):
                """Access the attributes of the class instance if available.

                If the attributes are not available, the attributes of the
                internal model instance are used instead.
                """
                # Try to retrieve the attribute from the class itself
                try:
                    val = super().__getattribute__(attr)
                except AttributeError:
                    pass
                else:
                    return val

                # If the attribute is not overwritten, use that of the internal model
                return self._model.__getattribute__(attr)

        # Catch constant targets if needed
        cls = (
            catch_constant_targets(CustomArchitectureSurrogate)
            if constant_target_catching
            else CustomArchitectureSurrogate
        )

        # Batchify the posterior if needed
        if batchify_posterior:
            cls._posterior = batchify(cls._posterior)

        # Block serialization of custom architectures
        converter.register_unstructure_hook(
            CustomArchitectureSurrogate, block_serialization_hook
        )

        return cls

    return construct_custom_architecture
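
# Example usage (a minimal sketch, not part of the library): the decorated class
# below is hypothetical and only illustrates the interface that
# ``validate_custom_architecture_cls`` expects, namely ``_fit`` and ``_posterior``
# methods with the signatures shown. It predicts the training-target mean with
# unit variance; the output shapes follow the (mean, variance) convention used by
# ``CustomONNXSurrogate._posterior`` below.
#
#     import torch
#
#     @register_custom_architecture(
#         joint_posterior_attr=False,
#         constant_target_catching=True,
#         batchify_posterior=True,
#     )
#     class MeanPredictionArchitecture:
#         """Hypothetical toy model predicting the training mean everywhere."""
#
#         def _fit(self, searchspace, train_x, train_y):
#             self._target_mean = train_y.mean().item()
#
#         def _posterior(self, candidates):
#             mean = torch.full((len(candidates), 1), self._target_mean)
#             return mean, torch.ones_like(mean)
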
@define(kw_only=True)
class CustomONNXSurrogate(Surrogate):
    """A wrapper class for custom pretrained surrogate models.

    Note that these surrogates cannot be retrained.
    """

    # Class variables
    joint_posterior: ClassVar[bool] = False  # See base class.
    supports_transfer_learning: ClassVar[bool] = False  # See base class.

    # Object variables
    onnx_input_name: str = field(validator=validators.instance_of(str))
    """The input name used for passing data to the ONNX model."""

    onnx_str: bytes = field(validator=validators.instance_of(bytes))
    """The ONNX byte string representing the model."""

    # TODO: type should be `onnxruntime.InferenceSession` but is currently
    #   omitted due to: https://github.com/python-attrs/cattrs/issues/531
    _model = field(init=False, eq=False)
    """The actual model."""
    @_model.default
    def default_model(self) -> ort.InferenceSession:
        """Instantiate the ONNX inference session."""
        from baybe._optional.onnx import onnxruntime as ort

        try:
            return ort.InferenceSession(self.onnx_str)
        except Exception as exc:
            raise ValueError("Invalid ONNX string") from exc
    @batchify
    def _posterior(self, candidates: Tensor) -> tuple[Tensor, Tensor]:
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        model_inputs = {
            self.onnx_input_name: candidates.numpy().astype(DTypeFloatONNX)
        }
        results = self._model.run(None, model_inputs)

        # IMPROVE: At the moment, we assume that the second model output contains
        #   standard deviations. Currently, most available ONNX converters care
        #   about the mean only and it's not clear how this will be handled in the
        #   future. Once there are more choices available, this should be revisited.
        return (
            torch.from_numpy(results[0]).to(DTypeFloatTorch),
            torch.from_numpy(results[1]).pow(2).to(DTypeFloatTorch),
        )

    def _fit(self, searchspace: SearchSpace, train_x: Tensor, train_y: Tensor) -> None:
        # TODO: This method actually needs to raise a NotImplementedError because
        #   ONNX surrogate models cannot be retrained. However, this would currently
        #   break the code since `BayesianRecommender` assumes that surrogates
        #   can be trained and attempts to do so for each new DOE iteration.
        #   Therefore, a refactoring is required in order to properly incorporate
        #   "static" surrogates and account for them in the exposed APIs.
        pass
    @classmethod
    def validate_compatibility(cls, searchspace: SearchSpace) -> None:
        """Validate if the class is compatible with a given search space.

        Args:
            searchspace: The search space to be tested for compatibility.

        Raises:
            TypeError: If the search space is incompatible with the class.
        """
        if not all(
            isinstance(
                p,
                (
                    NumericalContinuousParameter,
                    NumericalDiscreteParameter,
                    TaskParameter,
                ),
            )
            or (isinstance(p, CustomDiscreteParameter) and not p.decorrelate)
            or (
                isinstance(p, CategoricalParameter)
                and p.encoding is CategoricalEncoding.INT
            )
            for p in searchspace.parameters
        ):
            raise TypeError(
                f"To prevent potential hard-to-detect bugs that stem from wrong "
                f"wiring of model inputs, {cls.__name__} "
                f"is currently restricted for use with parameters that have "
                f"a one-dimensional computational representation or "
                f"{CustomDiscreteParameter.__name__}."
            )
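
# Example usage (a minimal sketch, not part of the library): one way to obtain a
# compatible ONNX byte string is to convert a fitted scikit-learn model via the
# ``skl2onnx`` package. The variable names and training data below are made up;
# the ``return_std`` option makes the converted model expose a second output
# with standard deviations, which ``_posterior`` above expects.
#
#     import numpy as np
#     from skl2onnx import convert_sklearn
#     from skl2onnx.common.data_types import FloatTensorType
#     from sklearn.gaussian_process import GaussianProcessRegressor
#
#     train_x = np.random.rand(10, 2).astype(np.float32)
#     train_y = np.random.rand(10, 1)
#
#     model = GaussianProcessRegressor().fit(train_x, train_y)
#
#     input_name = "input"
#     onnx_str = convert_sklearn(
#         model,
#         initial_types=[(input_name, FloatTensorType([None, train_x.shape[1]]))],
#         options={type(model): {"return_std": True}},
#     ).SerializeToString()
#
#     surrogate = CustomONNXSurrogate(onnx_input_name=input_name, onnx_str=onnx_str)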