Source code for baybe.surrogates.validation
"""Validation functionality for surrogates."""
from __future__ import annotations

from collections.abc import Callable, Mapping
from typing import Any

import cattrs
from cattrs import ClassValidationError
from cattrs.strategies import configure_union_passthrough

from baybe.surrogates.base import Surrogate
[docs]
def validate_custom_architecture_cls(model_cls: type) -> None:
    """Validate a custom architecture to have the correct attributes.

    The class must provide ``_fit`` and ``_posterior`` as plain Python
    functions whose positional parameter names match those of
    ``Surrogate._fit`` and ``Surrogate._posterior``, respectively.

    Args:
        model_cls: The user defined model class.

    Raises:
        ValueError: When model_cls does not have _fit or _posterior.
        ValueError: When _fit or _posterior is not a callable method.
        ValueError: When _fit does not have the required signature. This
            includes callables whose signature cannot be inspected via a
            ``__code__`` object (e.g. builtins or callable instances).
        ValueError: When _posterior does not have the required signature.
    """
    # Methods must exist
    if not (hasattr(model_cls, "_fit") and hasattr(model_cls, "_posterior")):
        raise ValueError(
            "`_fit` and a `_posterior` must exist for custom architectures"
        )

    fit = model_cls._fit
    posterior = model_cls._posterior

    # They must be methods
    if not (callable(fit) and callable(posterior)):
        raise ValueError(
            "`_fit` and a `_posterior` must be methods for custom architectures"
        )

    # Methods must have the correct arguments. Both sides are reduced to their
    # positional parameter names, so local variables of the reference
    # implementation can never cause a spurious mismatch. The user's method is
    # inspected first so that an uninspectable callable is rejected with a
    # ValueError instead of leaking an AttributeError.
    params = _positional_arg_names(fit)
    if params is None or params != _positional_arg_names(Surrogate._fit):
        raise ValueError(
            "Invalid args in `_fit` method definition for custom architecture. "
            "Please refer to Surrogate._fit for the required function signature."
        )

    params = _positional_arg_names(posterior)
    if params is None or params != _positional_arg_names(Surrogate._posterior):
        raise ValueError(
            "Invalid args in `_posterior` method definition for custom architecture. "
            "Please refer to Surrogate._posterior for the required function signature."
        )


def _positional_arg_names(func: Callable[..., Any]) -> tuple[str, ...] | None:
    """Return the positional parameter names of a plain Python function.

    Only the first ``co_argcount`` entries of ``co_varnames`` are taken, which
    are the positional parameters; keyword-only parameters, ``*args`` /
    ``**kwargs`` and local variables are excluded.

    Args:
        func: The callable to inspect.

    Returns:
        The tuple of positional parameter names, or ``None`` if ``func`` has no
        ``__code__`` attribute and thus cannot be inspected this way.
    """
    code = getattr(func, "__code__", None)
    if code is None:
        return None
    return code.co_varnames[: code.co_argcount]
# Create a strict type validation converter
# ``forbid_extra_keys=True`` asks cattrs to reject dictionary keys that the
# target type (e.g. a ``TypedDict`` passed by ``make_dict_validator``) does not
# declare. The strict primitive hooks are registered on this converter below.
type_validation_converter = cattrs.Converter(forbid_extra_keys=True)
"""Converter used for strict type validation."""
# Unions of these primitive types are handled with cattrs' union-passthrough
# strategy rather than by trying each member's structure hook in turn.
# NOTE(review): cattrs' ``accept_ints_as_floats`` option of this strategy is
# left at its default here — confirm that this is consistent with the strict
# float hook below for float-typed union members.
configure_union_passthrough(int | float | str | None, type_validation_converter)
@type_validation_converter.register_structure_hook
def _strict_int_structure_hook(obj: Any, _: type[int]) -> int:
    """Structure ``obj`` as an ``int`` without any implicit conversion.

    ``bool`` is a subclass of ``int`` in Python, so booleans are rejected
    explicitly even though they pass a plain ``isinstance(obj, int)`` check.

    Args:
        obj: The raw value to be structured.
        _: The target type (unused).

    Returns:
        ``obj`` unchanged, if it is a non-boolean ``int`` instance.

    Raises:
        ValueError: If ``obj`` is a ``bool`` or not an ``int`` instance.
    """
    if isinstance(obj, bool) or not isinstance(obj, int):
        type_name = type(obj).__name__
        raise ValueError(
            f"Value '{obj}' (type: {type_name}) is not a valid integer. "
            "Only actual 'int' instances are accepted."
        )
    return obj
@type_validation_converter.register_structure_hook
def _strict_float_structure_hook(obj: Any, _: type[float]) -> float:
    """Structure ``obj`` as a ``float`` without coercing other numeric types.

    Args:
        obj: The raw value to be structured.
        _: The target type (unused).

    Returns:
        ``obj`` unchanged, if it is a ``float`` instance.

    Raises:
        ValueError: If ``obj`` is not a ``float`` instance (an ``int`` such as
            ``1`` is therefore rejected).
    """
    if not isinstance(obj, float):
        type_name = type(obj).__name__
        raise ValueError(
            f"Value '{obj}' (type: {type_name}) is not a valid float. "
            "Only actual 'float' instances are accepted."
        )
    return obj
@type_validation_converter.register_structure_hook
def _strict_bool_structure_hook(obj: Any, _: type[bool]) -> bool:
    """Structure ``obj`` as a ``bool`` without truthiness-based conversion.

    Args:
        obj: The raw value to be structured.
        _: The target type (unused).

    Returns:
        ``obj`` unchanged, if it is a ``bool`` instance.

    Raises:
        ValueError: If ``obj`` is not a ``bool`` instance (e.g. ``0``/``1``).
    """
    if not isinstance(obj, bool):
        type_name = type(obj).__name__
        raise ValueError(
            f"Value '{obj}' (type: {type_name}) is not a valid boolean. "
            "Only actual 'bool' instances (True, False) are accepted."
        )
    return obj
[docs]
def make_dict_validator(specification: type) -> Callable:
    """Construct an attrs dictionary validator based on a ``TypedDict``.

    The returned validator structures the attribute value against
    ``specification`` using the module's strict ``type_validation_converter``
    and converts any validation failure into a ``TypeError``.

    Args:
        specification: Describes allowed keys and corresponding value types.

    Returns:
        An attrs compatible validator.
    """

    def validate_model_params(_instance: Any, attr: Any, value: dict) -> None:
        """Validate an attrs attribute against ``specification``.

        Args:
            _instance: The instance being validated (unused).
            attr: The attrs attribute; only its ``name`` is used, for messages.
            value: The value assigned to the attribute.

        Raises:
            TypeError: If ``value`` is not a mapping, or if structuring it with
                the strict converter fails.
        """
        # Reject non-mappings up front: cattrs may fail on them with an
        # exception other than ``ClassValidationError`` (e.g. when treating
        # ``None`` as a dictionary), which would bypass the handler below.
        if not isinstance(value, Mapping):
            raise TypeError(
                f"The provided value for '{attr.name}' must be a dictionary, "
                f"got an object of type '{type(value).__name__}'."
            )
        try:
            type_validation_converter.structure(value, specification)
        except ClassValidationError as ex:
            raise TypeError(
                f"The provided dictionary for '{attr.name}' is invalid."
            ) from ex

    return validate_model_params