CustomONNXSurrogate
- class baybe.surrogates.custom.CustomONNXSurrogate
Bases: Surrogate
A wrapper class for custom pretrained surrogate models.
Note that these surrogates cannot be retrained.
Public methods
__init__(*, onnx_input_name, onnx_str): Method generated by attrs for class CustomONNXSurrogate.
Instantiate the ONNX inference session.
fit(searchspace, train_x, train_y): Train the surrogate model on the provided data.
from_dict(dictionary): Create an object from its dictionary representation.
from_json(string): Create an object from its JSON representation.
posterior(candidates): Evaluate the surrogate model at the given candidate points.
to_botorch(): Create the botorch-ready representation of the model.
to_dict(): Create an object's dictionary representation.
to_json(): Create an object's JSON representation.
validate_compatibility(searchspace): Validate if the class is compatible with a given search space.
Public attributes and properties
onnx_input_name: The input name used for constructing the ONNX str.
onnx_str: The ONNX byte str representing the model.
Class variable encoding whether or not a joint posterior is calculated.
Class variable encoding whether or not the surrogate supports transfer learning.
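When a surrogate does not compute a joint posterior, one natural reading (assumed here, not stated on this page) is that the points of a q-batch are treated as independent, so their covariance matrix is diagonal. The sketch below builds such a matrix from per-point variances in plain Python; the helper name diagonal_covariance is hypothetical and not part of BayBE.

```python
def diagonal_covariance(variances: list[float]) -> list[list[float]]:
    """Build a q x q covariance matrix with the given variances on the diagonal.

    Hypothetical helper: off-diagonal entries are zero because, without a
    joint posterior, no correlations between candidate points are modeled.
    """
    q = len(variances)
    return [[variances[i] if i == j else 0.0 for j in range(q)] for i in range(q)]


# Three candidate points with independent predictive variances.
cov = diagonal_covariance([0.5, 1.0, 2.0])
for row in cov:
    print(row)
```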
- __init__(*, onnx_input_name: str, onnx_str: bytes)
Method generated by attrs for class CustomONNXSurrogate.
For details on the parameters, see Public attributes and properties.
- fit(searchspace: SearchSpace, train_x: Tensor, train_y: Tensor)
Train the surrogate model on the provided data.
- Parameters:
searchspace (SearchSpace) – The search space in which experiments are conducted.
train_x (Tensor) – The training data points.
train_y (Tensor) – The training data labels.
- Raises:
ValueError – If the search space contains task parameters but the selected surrogate model type does not support transfer learning.
NotImplementedError – If a continuous search space is used with a non-GP model.
- Return type:
None
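The two documented failure modes of fit can be expressed as a precondition check. The function below is hypothetical (not part of BayBE): its boolean flags stand in for properties the real method would read from the search space and the surrogate, and the order of the checks is an assumption.

```python
def check_fit_preconditions(
    *,
    has_task_parameters: bool,
    supports_transfer_learning: bool,
    is_continuous: bool,
    is_gp: bool,
) -> None:
    """Raise the errors documented for ``fit`` when their conditions hold.

    Hypothetical helper: the flags stand in for search space and surrogate
    properties; the check order is an assumption.
    """
    if has_task_parameters and not supports_transfer_learning:
        raise ValueError(
            "The search space contains task parameters, but the surrogate "
            "does not support transfer learning."
        )
    if is_continuous and not is_gp:
        raise NotImplementedError(
            "Continuous search spaces require a GP surrogate model."
        )


# Task parameters combined with a surrogate lacking transfer learning support
# are rejected.
try:
    check_fit_preconditions(
        has_task_parameters=True,
        supports_transfer_learning=False,
        is_continuous=False,
        is_gp=False,
    )
except ValueError as err:
    print("ValueError:", err)
```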
- posterior(candidates: Tensor)
Evaluate the surrogate model at the given candidate points.
- Parameters:
candidates (Tensor) – The candidate points, represented as a tensor of shape (*t, q, d), where t denotes the “t-batch” shape, q denotes the “q-batch” shape, and d is the input dimension. For more details about batch shapes, see: https://botorch.org/docs/batching
- Return type:
tuple[Tensor, Tensor]
- Returns:
The posterior means and posterior covariance matrices of the t-batched candidate points.
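The batch-shape convention can be made concrete by computing the output shapes for a given input shape. This is a minimal sketch assuming a single-output model, so that each candidate point gets one mean and each t-batch gets one q x q covariance matrix; the helper posterior_shapes is hypothetical and only manipulates shape tuples.

```python
def posterior_shapes(
    candidate_shape: tuple[int, ...],
) -> tuple[tuple[int, ...], tuple[int, ...]]:
    """Return (mean_shape, covariance_shape) for candidates of shape (*t, q, d).

    Assumes a single-output model: one mean per candidate point and one
    q x q covariance matrix per t-batch.
    """
    if len(candidate_shape) < 2:
        raise ValueError("candidates need at least a q and a d dimension")
    *t, q, _d = candidate_shape  # t-batch dims, q-batch size, input dim
    return (*t, q), (*t, q, q)


# 5 t-batches, each a q-batch of 3 points with 4 input features.
print(posterior_shapes((5, 3, 4)))  # ((5, 3), (5, 3, 3))
# No t-batch dimensions: a single q-batch of 3 points.
print(posterior_shapes((3, 4)))  # ((3,), (3, 3))
```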
- to_botorch()
Create the botorch-ready representation of the model.
- Return type:
Model
- to_json()
Create an object’s JSON representation.
- Return type:
str
- Returns:
The JSON representation as a string.
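JSON has no native bytes type, so the onnx_str field has to be encoded as text by to_json and decoded again by from_json. This page does not say which encoding BayBE uses; the sketch below assumes base64, and the helpers fields_to_json and fields_from_json are hypothetical, not library functions.

```python
import base64
import json


def fields_to_json(onnx_input_name: str, onnx_str: bytes) -> str:
    """Serialize the two fields to JSON, base64-encoding the raw model bytes."""
    payload = {
        "onnx_input_name": onnx_input_name,
        # JSON cannot hold raw bytes, so they are stored as ASCII text.
        "onnx_str": base64.b64encode(onnx_str).decode("ascii"),
    }
    return json.dumps(payload)


def fields_from_json(string: str) -> tuple[str, bytes]:
    """Invert fields_to_json, decoding the base64 text back into bytes."""
    payload = json.loads(string)
    return payload["onnx_input_name"], base64.b64decode(payload["onnx_str"])


text = fields_to_json("input", b"\x08\x07model-bytes")
print(text)
print(fields_from_json(text))
```

Because base64 is lossless, the round trip returns exactly the original bytes, which is what an inference session would need to reload the model.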
- classmethod validate_compatibility(searchspace: SearchSpace)
Validate if the class is compatible with a given search space.
- Parameters:
searchspace (SearchSpace) – The search space to be tested for compatibility.
- Raises:
TypeError – If the search space is incompatible with the class.
- Return type: