Source code for baybe.utils.validation
"""Validation utilities."""
from __future__ import annotations

import math
from collections import Counter
from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING, Any

import numpy as np
import pandas as pd
from attrs import Attribute

from baybe.exceptions import IncompleteMeasurementsError

if TYPE_CHECKING:
    from baybe.objectives.base import Objective
    from baybe.parameters.base import Parameter
    from baybe.targets.base import Target
[docs]
def validate_not_nan(self: Any, attribute: Attribute, value: Any) -> None:
    """Reject float ``nan`` values; usable as an attrs validator.

    Args:
        self: The instance owning the attribute (only used in the error message).
        attribute: The attrs attribute being validated.
        value: The candidate value. Non-float values are accepted unchanged.

    Raises:
        ValueError: If ``value`` is a float (or float subclass) equal to ``nan``.
    """
    if not isinstance(value, float):
        return
    if not math.isnan(value):
        return
    raise ValueError(
        f"The value passed to attribute '{attribute.name}' of class "
        f"'{self.__class__.__name__}' cannot be 'nan'."
    )
def _make_restricted_float_validator(
allow_nan: bool, allow_inf: bool
) -> Callable[[Any, Attribute, Any], None]:
"""Make an attrs-compatible validator for restricted floats.
Args:
allow_nan: If False, validated values cannot be 'nan'.
allow_inf: If False, validated values cannot be 'inf' or '-inf'.
Raises:
ValueError: If no float range restriction is in place.
Returns:
The validator.
"""
if allow_nan and allow_inf:
raise ValueError(
"The requested validator would not restrict the float range. "
"Hence, you can use `attrs.validators.instance_of(float)` instead."
)
def validator(self: Any, attribute: Attribute, value: Any) -> None:
if not isinstance(value, float):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' must be of type 'float'. "
f"Given: {value} (type: {type(value)})"
)
if not allow_inf and math.isinf(value):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' cannot be 'inf' or '-inf'."
)
if not allow_nan and math.isnan(value):
raise ValueError(
f"Values assigned to attribute '{attribute.name}' of class "
f"'{self.__class__.__name__}' cannot be 'nan'."
)
return validator
finite_float = _make_restricted_float_validator(allow_inf=False, allow_nan=False)
"""Validator accepting only finite floats (neither nan nor +/-inf)."""

non_nan_float = _make_restricted_float_validator(allow_inf=True, allow_nan=False)
"""Validator accepting any float except nan (infinite values pass)."""

non_inf_float = _make_restricted_float_validator(allow_inf=False, allow_nan=True)
"""Validator accepting any float except +/-inf (nan passes)."""
[docs]
def validate_target_input(data: pd.DataFrame, targets: Iterable[Target]) -> None:
    """Validate input dataframe columns corresponding to targets.

    Args:
        data: The input dataframe to be validated.
        targets: The allowed targets. Any iterable is accepted; it is
            materialized once internally.

    Raises:
        ValueError: If the data is empty.
        ValueError: If the data misses columns for a target.
        TypeError: If any numerical target data contain non-numeric values.
        ValueError: If any binary target data contain values not part of the targets'
            allowed values or NaN. Missing values themselves are permitted.
    """
    from baybe.targets import BinaryTarget, NumericalTarget

    # The targets are traversed twice below (column check + per-target loop).
    # Materialize them so one-shot iterators do not silently skip the loop.
    targets = list(targets)

    if data.empty:
        raise ValueError("The provided input dataframe cannot be empty.")

    if missing := {t.name for t in targets}.difference(data.columns):
        raise ValueError(
            f"The input dataframe is missing columns for the following targets: "
            f"{missing}"
        )

    for t in targets:
        if isinstance(t, NumericalTarget):
            if data[t.name].dtype.kind not in "iufb":
                raise TypeError(
                    f"The numerical target '{t.name}' has non-numeric entries in the "
                    f"provided dataframe."
                )
        elif isinstance(t, BinaryTarget):
            allowed = {t.failure_value, t.success_value, np.nan}
            # Drop missing values before the set comparison. NaN never compares
            # equal to itself, and the NaN scalars produced by `unique()` are
            # distinct objects from `np.nan`. Without dropping them, they would
            # survive the set difference and be reported as invalid.
            observed = set(data[t.name].dropna().unique())
            if invalid := observed - allowed:
                raise ValueError(
                    f"The binary target '{t.name}' has invalid entries {invalid} "
                    f"in the provided dataframe. Allowed values are: {allowed}."
                )
[docs]
def validate_objective_input(data: pd.DataFrame, objective: Objective) -> None:
    """Check that the target columns of ``data`` are compatible with ``objective``.

    Only the columns named after the objective's targets are inspected. If the
    objective does not support partial measurements, none of those columns may
    contain missing values.

    Args:
        data: The input dataframe to be validated.
        objective: The objective to validate against.

    Raises:
        IncompleteMeasurementsError: If the objective requires complete measurements
            but the input dataframe contains missing target values.
    """
    target_columns = [target.name for target in objective.targets]
    subset = data[target_columns]

    if objective.supports_partial_measurements:
        return

    incomplete = subset.columns[subset.isna().any()].tolist()
    if not incomplete:
        return

    raise IncompleteMeasurementsError(
        f"The used objective requires complete measurements of all "
        f"involved targets but the provided dataframe contains missing "
        f"values for the following targets: {incomplete}"
    )
[docs]
def validate_parameter_input(
    data: pd.DataFrame,
    parameters: Iterable[Parameter],
    numerical_measurements_must_be_within_tolerance: bool = False,
) -> None:
    """Validate input dataframe columns corresponding to parameters.

    Args:
        data: The input dataframe to be validated.
        parameters: The allowed parameters. Any iterable is accepted; it is
            materialized once internally.
        numerical_measurements_must_be_within_tolerance: If ``True``, numerical
            parameter values must match to parameter values within the
            parameter-specific tolerance. If ``False``, numerical parameters are
            not checked via ``is_in_range``. Non-numerical parameters are always
            checked.

    Raises:
        ValueError: If the data is empty.
        ValueError: If the data misses columns for a parameter.
        ValueError: If a parameter contains NaN.
        TypeError: If a parameter contains non-numeric values.
        ValueError: If a value is rejected by the parameter's ``is_in_range``.
    """
    # The parameters are traversed twice below (column check + per-parameter
    # loop). Materialize them so one-shot iterators do not skip all checks.
    parameters = list(parameters)

    if data.empty:
        raise ValueError("The provided input dataframe cannot be empty.")

    if missing := {p.name for p in parameters}.difference(data.columns):
        raise ValueError(
            f"The input dataframe is missing columns for the following parameters: "
            f"{missing}"
        )

    for p in parameters:
        column = data[p.name]

        if column.isna().any():
            raise ValueError(
                f"The parameter '{p.name}' has missing values in the provided "
                f"dataframe."
            )
        if p.is_numerical and (column.dtype.kind not in "iufb"):
            raise TypeError(
                f"The numerical parameter '{p.name}' has non-numeric entries in the "
                f"provided dataframe."
            )

        # Numerical values are only range-checked on request.
        # Non-numerical (e.g. categorical) values are always checked.
        if p.is_numerical and not numerical_measurements_must_be_within_tolerance:
            continue

        # Iterate this single column rather than `data.iterrows()`. That avoids
        # scanning every column per parameter and avoids the row-wise dtype
        # upcasting that `iterrows` performs.
        for index, value in column.items():
            if not p.is_in_range(value):
                raise ValueError(
                    f"Input data on row with the index {index} has invalid "
                    f"values in parameter '{p.name}'. "
                    f"For categorical parameters, values need to exactly match a "
                    f"valid choice defined in your config. "
                    f"For numerical parameters, a match is accepted only if "
                    f"the input value is within the specified tolerance/range. Set "
                    f"the flag 'numerical_measurements_must_be_within_tolerance' "
                    f"to 'False' to disable this behavior."
                )
[docs]
def validate_object_names(objects: Iterable[Parameter | Target], /) -> None:
    """Validate that the provided objects have unique names.

    Args:
        objects: An iterable containing a combination of parameters and targets.

    Raises:
        ValueError: If two or more objects have the same name.
    """
    # Count all names in one pass. Calling `list.count` per name would be
    # quadratic in the number of objects.
    counts = Counter(obj.name for obj in objects)
    if duplicates := {name for name, count in counts.items() if count > 1}:
        raise ValueError(
            f"All parameters and targets must have unique names. The following names "
            f"appear multiple times: {duplicates}."
        )