"""A wrapper class for synthetic BoTorch test functions."""importtorchfrombotorch.test_functionsimportSyntheticTestFunction
def botorch_function_wrapper(test_function: SyntheticTestFunction):
    """Turn a BoTorch test function into a format accepted by lookup in simulations.

    See :mod:`baybe.simulation` for details.

    Args:
        test_function: The synthetic test function from BoTorch. See
            https://botorch.org/api/test_functions.html.

    Returns:
        A wrapped version of the provided function. The wrapper takes the input
        coordinates as separate positional arguments, evaluates
        ``test_function.forward`` on them, and returns the result as a plain
        Python ``float``.
    """

    def wrapper(*x: float) -> float:
        # Collect the positional inputs into a single tensor with one entry
        # per input coordinate.
        x_tensor = torch.tensor(x)

        # ``torch.tensor`` infers an integer (or bool) dtype when no input is a
        # float, e.g. ``wrapper(1, 2)``. Many BoTorch test functions rely on
        # float-only operations such as norms, so promote such tensors to the
        # default floating dtype. Floating inputs keep their inferred dtype.
        if not torch.is_floating_point(x_tensor):
            x_tensor = x_tensor.to(torch.get_default_dtype())

        # NOTE(review): per BoTorch's documentation, ``forward`` (unlike
        # ``evaluate_true``) adds observation noise when the test function was
        # constructed with ``noise_std``; confirm this is the intended behavior.
        result = test_function.forward(x_tensor)

        # Return a plain scalar rather than a tensor.
        return float(result)

    return wrapper