"""Functionality for desirability objectives."""
from __future__ import annotations
import gc
from functools import cached_property
from typing import TYPE_CHECKING, ClassVar
import cattrs
import numpy as np
import pandas as pd
from attrs import define, field, fields
from attrs.validators import deep_iterable, gt, instance_of, min_len
from typing_extensions import override
from baybe.exceptions import IncompatibilityError
from baybe.objectives.base import Objective
from baybe.objectives.enum import Scalarizer
from baybe.objectives.validation import validate_target_names
from baybe.targets import NumericalTarget
from baybe.targets.base import Target
from baybe.targets.numerical import UncertainBool
from baybe.utils.basic import to_tuple
from baybe.utils.conversion import to_string
from baybe.utils.dataframe import pretty_print_df
from baybe.utils.validation import finite_float
if TYPE_CHECKING:
from botorch.acquisition.objective import MCAcquisitionObjective
from torch import Tensor
_OUTPUT_NAME = "Desirability"
"""The name of the output column produced by the desirability transform."""
def _geometric_mean(x: Tensor, /, weights: Tensor, dim: int = -1) -> Tensor:
"""Calculate the geometric mean of an array along a given dimension.
Args:
x: A tensor containing the values for the mean computation.
weights: A tensor of weights whose shape must be broadcastable to the shape
of the input tensor.
dim: The dimension along which to compute the geometric mean.
Returns:
A tensor containing the weighted geometric means.
"""
import torch
# Ensure x is a floating-point tensor
if not torch.is_floating_point(x):
x = x.float()
# Normalize weights
normalized_weights = weights / torch.sum(weights)
# Add epsilon to avoid log(0)
eps = torch.finfo(x.dtype).eps
log_tensor = torch.log(x + eps)
# Compute the weighted log sum
weighted_log_sum = torch.sum(log_tensor * normalized_weights.unsqueeze(0), dim=dim)
# Convert back from log domain
return torch.exp(weighted_log_sum)
@define(frozen=True, slots=False)
class DesirabilityObjective(Objective):
    """An objective scalarizing multiple targets using desirability values."""

    is_multi_output: ClassVar[bool] = False
    # See base class.

    _targets: tuple[NumericalTarget, ...] = field(
        converter=to_tuple,
        validator=[
            min_len(2),
            deep_iterable(member_validator=instance_of(NumericalTarget)),
            validate_target_names,
        ],
        alias="targets",
    )
    "The targets considered by the objective."

    weights: tuple[float, ...] = field(
        converter=lambda w: cattrs.structure(w, tuple[float, ...]),
        validator=deep_iterable(member_validator=[finite_float, gt(0.0)]),
    )
    """The weights to balance the different targets.
    By default, all targets are considered equally important."""

    scalarizer: Scalarizer = field(default=Scalarizer.GEOM_MEAN, converter=Scalarizer)
    """The mechanism to scalarize the weighted desirability values of all targets."""

    require_normalization: bool = field(
        default=True, validator=instance_of(bool), kw_only=True
    )
    """Controls if the targets must be normalized."""

    as_pre_transformation: bool = field(
        default=False, validator=instance_of(bool), kw_only=True
    )
    """Controls if the desirability computation is applied as a pre-transformation."""

    @weights.default
    def _default_weights(self) -> tuple[float, ...]:
        """Create unit weights for all targets."""
        return tuple(1.0 for _ in range(len(self.targets)))

    @_targets.validator
    def _validate_targets(self, _, targets) -> None:  # noqa: DOC101, DOC103
        # Geometric averaging operates in log space, so all target images must
        # be non-negative for the scalarization to be well-defined.
        if self.scalarizer is Scalarizer.GEOM_MEAN and (
            negative := {t.name for t in targets if t.get_codomain().lower < 0}
        ):
            raise ValueError(
                f"Using '{Scalarizer.GEOM_MEAN}' for '{self.__class__.__name__}' "
                f"requires that all targets are transformed to a non-negative range. "
                f"However, the images of the following targets cover negative values: "
                f"{negative}."
            )

        # Validate normalization. NOTE: the fix below corrects the advice in the
        # error message — unnormalized targets are allowed by setting
        # 'require_normalization' to 'False', not 'True'.
        if self.require_normalization and (
            unnormalized := {
                t.name for t in targets if t.is_normalized is not UncertainBool.TRUE
            }
        ):
            raise ValueError(
                f"By default, '{self.__class__.__name__}' only accepts normalized "
                f"targets but the following targets are either not normalized or their "
                f"normalization status is unclear because the image "
                f"of the underlying transformation is unknown: {unnormalized}. "
                f"Either normalize your targets (e.g. using their "
                f"'{NumericalTarget.normalize.__name__}' method / by specifying "
                f"a suitable target transformation) or explicitly set "
                f"'{DesirabilityObjective.__name__}."
                f"{fields(DesirabilityObjective).require_normalization.name}' to "
                f"'False' to allow unnormalized targets."
            )

    @weights.validator
    def _validate_weights(self, _, weights) -> None:  # noqa: DOC101, DOC103
        # There must be exactly one weight per target.
        if (lw := len(weights)) != (lt := len(self.targets)):
            raise ValueError(
                f"If custom weights are specified, there must be one for each target. "
                f"Specified number of targets: {lt}. Specified number of weights: {lw}."
            )

    @override
    @property
    def targets(self) -> tuple[NumericalTarget, ...]:
        return self._targets

    @override
    @property
    def _modeled_quantity_names(self) -> tuple[str, ...]:
        # In pre-transformation mode, the surrogate models the scalarized
        # desirability value directly; otherwise, it models the raw targets.
        return (
            self.output_names
            if self.as_pre_transformation
            else tuple(t.name for t in self.targets)
        )

    @override
    @property
    def output_names(self) -> tuple[str, ...]:
        return (_OUTPUT_NAME,)

    @override
    @property
    def supports_partial_measurements(self) -> bool:
        # Pre-transformation requires all target columns simultaneously.
        return not self.as_pre_transformation

    @cached_property
    def _normalized_weights(self) -> np.ndarray:
        """The normalized target weights."""
        return np.asarray(self.weights) / np.sum(self.weights)

    @override
    def __str__(self) -> str:
        targets_list = [target.summary() for target in self.targets]
        targets_df = pd.DataFrame(targets_list)
        targets_df["Weight"] = self.weights
        # Renamed from 'fields' to avoid shadowing the `attrs.fields` import
        # used elsewhere in this class.
        sections = [
            to_string("Type", self.__class__.__name__, single_line=True),
            to_string("Targets", pretty_print_df(targets_df)),
            to_string("Scalarizer", self.scalarizer.name, single_line=True),
        ]
        return to_string("Objective", *sections)

    @override
    @property
    def _oriented_targets(self) -> tuple[Target, ...]:
        # For desirability, we do not only negate but also shift by 1 so that
        # normalized minimization targets are still mapped to [0, 1] instead of [-1, 0]
        # to enable geometric averaging.
        return tuple(
            t.negate() + 1 if isinstance(t, NumericalTarget) and t.minimize else t
            for t in self.targets
        )

    @override
    @property
    def _full_transformation(self) -> MCAcquisitionObjective:
        return self._to_botorch_full()

    @override
    def to_botorch(self) -> MCAcquisitionObjective:
        if self.as_pre_transformation:
            # The desirability value is already computed upfront, so only the
            # (trivial) transformation of the output column remains.
            return NumericalTarget(_OUTPUT_NAME).to_objective().to_botorch()
        else:
            return self._to_botorch_full()

    def _to_botorch_full(self) -> MCAcquisitionObjective:
        """Create a BoTorch objective conducting the full desirability transform.

        Full transformation means:
        1. Starting from the raw target values
        2. Applying the individual target transformations
        3. Scalarizing the transformed values into a desirability score

        This differs from the regular :meth:`to_botorch` in that the entire
        transformation step is represented end-to-end by the returned objective,
        whereas the former only captures the part of the transformation starting
        from the point where the surrogate model(s) are applied (i.e. which may or
        may not include the desirability scalarization step, depending on the
        chosen `as_pre_transformation` setting).
        """
        import torch
        from botorch.acquisition.objective import GenericMCObjective, LinearMCObjective

        from baybe.objectives.botorch import ChainedMCObjective

        # Build the scalarization ("outer") part of the transformation
        if self.scalarizer is Scalarizer.MEAN:
            outer = LinearMCObjective(torch.tensor(self._normalized_weights))
        elif self.scalarizer is Scalarizer.GEOM_MEAN:
            outer = GenericMCObjective(
                lambda samples, X: _geometric_mean(
                    samples, torch.tensor(self._normalized_weights)
                )
            )
        else:
            raise NotImplementedError(
                f"No scalarization mechanism defined for '{self.scalarizer.name}'."
            )

        # Chain the per-target transformations ("inner") with the scalarization
        inner = super().to_botorch()
        return ChainedMCObjective(inner, outer)

    @override
    def _pre_transform(
        self,
        df: pd.DataFrame,
        /,
        *,
        allow_missing: bool = False,
        allow_extra: bool = False,
    ) -> pd.DataFrame:
        if not self.as_pre_transformation:
            return super()._pre_transform(
                df, allow_missing=allow_missing, allow_extra=allow_extra
            )

        if allow_missing:
            raise IncompatibilityError(
                f"Setting 'allow_missing=True' is not supported for "
                f"'{self.__class__.__name__}.{self._pre_transform.__name__}' when "
                f"'{fields(self.__class__).as_pre_transformation.name}=True' since "
                f"the involved desirability computation requires all target columns "
                f"to be present."
            )

        return self.transform(df, allow_missing=False, allow_extra=allow_extra)
# Collect leftover original slotted classes processed by `attrs.define`.
# NOTE(review): presumably this frees the intermediate class objects that
# `attrs` creates when building slotted classes — confirm it is still needed.
gc.collect()