"""Validation functionality for surrogates."""from__future__importannotationsfromcollections.abcimportCallablefromtypingimportAnyimportcattrsfromcattrsimportClassValidationErrorfromcattrs.strategiesimportconfigure_union_passthroughfrombaybe.surrogates.baseimportSurrogate
[docs]defvalidate_custom_architecture_cls(model_cls:type)->None:"""Validate a custom architecture to have the correct attributes. Args: model_cls: The user defined model class. Raises: ValueError: When model_cls does not have _fit or _posterior. ValueError: When _fit or _posterior is not a callable method. ValueError: When _fit does not have the required signature. ValueError: When _posterior does not have the required signature. """# Methods must existifnot(hasattr(model_cls,"_fit")andhasattr(model_cls,"_posterior")):raiseValueError("`_fit` and a `_posterior` must exist for custom architectures")fit=model_cls._fitposterior=model_cls._posterior# They must be methodsifnot(callable(fit)andcallable(posterior)):raiseValueError("`_fit` and a `_posterior` must be methods for custom architectures")# Methods must have the correct argumentsparams=fit.__code__.co_varnames[:fit.__code__.co_argcount]ifparams!=Surrogate._fit.__code__.co_varnames:raiseValueError("Invalid args in `_fit` method definition for custom architecture. ""Please refer to Surrogate._fit for the required function signature.")params=posterior.__code__.co_varnames[:posterior.__code__.co_argcount]ifparams!=Surrogate._posterior.__code__.co_varnames:raiseValueError("Invalid args in `_posterior` method definition for custom architecture. ""Please refer to Surrogate._posterior for the required function signature.")
#: Converter used for strict type validation.
type_validation_converter = cattrs.Converter(forbid_extra_keys=True)

# Enable cattrs' union passthrough for unions built from these primitive types.
configure_union_passthrough(int | float | str | None, type_validation_converter)


@type_validation_converter.register_structure_hook
def _strict_int_structure_hook(obj: Any, _: type[int]) -> int:
    """Pass through genuine ``int`` values; reject everything else.

    ``bool`` is a subclass of ``int`` in Python, so it is excluded explicitly.
    """
    is_genuine_int = isinstance(obj, int) and not isinstance(obj, bool)
    if not is_genuine_int:
        raise ValueError(
            f"Value '{obj}' (type: {type(obj).__name__}) is not a valid integer. "
            "Only actual 'int' instances are accepted."
        )
    return obj


@type_validation_converter.register_structure_hook
def _strict_float_structure_hook(obj: Any, _: type[float]) -> float:
    """Pass through ``float`` instances only; ints are not coerced."""
    if not isinstance(obj, float):
        raise ValueError(
            f"Value '{obj}' (type: {type(obj).__name__}) is not a valid float. "
            "Only actual 'float' instances are accepted."
        )
    return obj


@type_validation_converter.register_structure_hook
def _strict_bool_structure_hook(obj: Any, _: type[bool]) -> bool:
    """Pass through ``bool`` instances only; truthy/falsy values are rejected."""
    if not isinstance(obj, bool):
        raise ValueError(
            f"Value '{obj}' (type: {type(obj).__name__}) is not a valid boolean. "
            "Only actual 'bool' instances (True, False) are accepted."
        )
    return obj
def make_dict_validator(specification: type) -> Callable:
    """Construct an attrs dictionary validator based on a ``TypedDict``.

    Args:
        specification: Describes allowed keys and corresponding value types.

    Returns:
        An attrs compatible validator. It raises a ``TypeError`` when the
        validated value is not a mapping or does not conform to
        ``specification``.
    """
    from collections.abc import Mapping

    def validate_model_params(_instance: Any, attr: Any, value: dict) -> None:
        """Validate an attrs attribute value against ``specification``.

        Structuring goes through ``type_validation_converter``, whose strict
        int/float/bool hooks reject implicit type coercions and which forbids
        extra keys.

        Raises:
            TypeError: If ``value`` is not a mapping, or if structuring it
                against ``specification`` fails validation.
        """
        # Non-mapping input would otherwise reach the converter and may surface
        # as an error other than ClassValidationError (escaping the handler
        # below), so reject it up front with the documented exception type.
        if not isinstance(value, Mapping):
            raise TypeError(
                f"The provided value for '{attr.name}' must be a dictionary, "
                f"got {type(value).__name__}."
            )
        try:
            type_validation_converter.structure(value, specification)
        except ClassValidationError as ex:
            raise TypeError(
                f"The provided dictionary for '{attr.name}' is invalid."
            ) from ex

    return validate_model_params