"""Base classes for all objectives."""
from __future__ import annotations
import gc
import warnings
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, ClassVar
import pandas as pd
from attrs import define, field
from baybe.serialization.mixin import SerialMixin
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import is_all_instance
from baybe.utils.dataframe import get_transform_objects, to_tensor
from baybe.utils.dataframe import (
handle_missing_values as df_handle_missing_values,
)
from baybe.utils.metadata import Metadata, to_metadata
from baybe.utils.validation import validate_target_input
if TYPE_CHECKING:
from botorch.acquisition.objective import MCAcquisitionObjective, PosteriorTransform
# TODO: Reactivate slots in all classes once cached_property is supported:
# https://github.com/python-attrs/attrs/issues/164
@define(frozen=True, slots=False)
class Objective(ABC, SerialMixin):
    """Abstract base class for all objectives."""

    is_multi_output: ClassVar[bool]
    """Class variable indicating if the objective produces multiple outputs."""

    metadata: Metadata = field(
        factory=Metadata,
        converter=lambda x: to_metadata(x, Metadata),
        kw_only=True,
    )
    """Optional metadata containing description and other information."""

    @property
    def description(self) -> str | None:
        """The description of the objective, as stored in its metadata."""
        return self.metadata.description

    @property
    @abstractmethod
    def targets(self) -> tuple[Target, ...]:
        """The targets included in the objective."""

    @property
    def _modeled_quantities(self) -> tuple[Target, ...]:
        """The quantities modeled by this objective.

        By default, the targets themselves are the modeled quantities.
        """
        return self.targets

    @property
    def _modeled_quantity_names(self) -> tuple[str, ...]:
        """The names of the quantities returned by the pre-transformation."""
        return tuple(t.name for t in self._modeled_quantities)

    @property
    def _model_quantities_to_target_names(self) -> dict[str, list[str]]:
        """The mapping from modeled quantity names to names of the required targets.

        By default, each modeled quantity requires only the target of the same name.
        """
        return {mq.name: [mq.name] for mq in self._modeled_quantities}

    @property
    def _n_models(self) -> int:
        """The number of models used in the objective.

        Corresponds to the number of dimensions after the pre-transformation.
        """
        return len(self._modeled_quantities)

    @property
    def _is_multi_model(self) -> bool:
        """Check if the objective relies on multiple surrogate models."""
        return self._n_models > 1

    @property
    @abstractmethod
    def output_names(self) -> tuple[str, ...]:
        """The names of the outputs of the objective."""

    @property
    def n_outputs(self) -> int:
        """The number of outputs of the objective."""
        return len(self.output_names)

    @property
    @abstractmethod
    def supports_partial_measurements(self) -> bool:
        """Boolean indicating if the objective accepts partial target measurements."""

    @property
    def _oriented_targets(self) -> tuple[Target, ...]:
        """The targets with optional negation transformation for minimization.

        Numerical targets with ``minimize`` set are replaced by their negated
        versions; all other targets are passed through unchanged.
        """
        return tuple(
            t.negate() if isinstance(t, NumericalTarget) and t.minimize else t
            for t in self.targets
        )

    @property
    def _full_transformation(self) -> MCAcquisitionObjective:
        """The end-to-end transformation applied, from targets to objective values."""
        return self.to_botorch()

    def handle_missing_values(
        self, measurements: pd.DataFrame
    ) -> dict[str, pd.DataFrame]:
        """Handle missing values in the given measurements for each modeled quantity.

        For each modeled quantity, the dataframe-level missing-value helper is
        applied (with ``drop=True``) to the targets required by that quantity.

        Args:
            measurements: Data potentially containing missing values.

        Returns:
            A dictionary with one dataframe for each modeled quantity.
        """
        return {
            quantity: df_handle_missing_values(measurements, target_names, drop=True)
            for quantity, target_names in self._model_quantities_to_target_names.items()
        }

    def to_botorch(self) -> MCAcquisitionObjective:
        """Convert to BoTorch objective.

        The returned objective applies the BoTorch objective of the ``i``-th
        oriented target's transformation to ``samples[..., i]`` and stacks the
        results along the last dimension.

        Raises:
            NotImplementedError: If not all targets are numerical.
        """
        if not is_all_instance(targets := self._oriented_targets, NumericalTarget):
            raise NotImplementedError(
                "Conversion to BoTorch is only supported for numerical targets."
            )

        import torch
        from botorch.acquisition.multi_objective.objective import (
            GenericMCMultiOutputObjective,
        )

        # The per-target callables do not depend on the samples, so build them once
        # here instead of on every evaluation of the objective.
        transforms = [t.transformation.to_botorch_objective() for t in targets]

        return GenericMCMultiOutputObjective(
            lambda samples, X: torch.stack(
                [tf(samples[..., i]) for i, tf in enumerate(transforms)],
                dim=-1,
            )
        )

    @abstractmethod
    def to_botorch_posterior_transform(self) -> PosteriorTransform:
        """Convert to BoTorch posterior transform, if possible.

        A representation as posterior transformation is only possible if Gaussianity
        is preserved by the involved operations, that is, when all targets are
        inherently numerical and their assigned transformations are affine.
        """

    def _pre_transform(
        self,
        df: pd.DataFrame,
        /,
        *,
        allow_missing: bool = False,
        allow_extra: bool = False,
    ) -> pd.DataFrame:
        """Pre-transform the target values prior to predictive modeling.

        For details on the method arguments, see :meth:`transform`.
        """
        # By default, we just pipe through the unmodified target values
        targets = get_transform_objects(
            df, self.targets, allow_missing=allow_missing, allow_extra=allow_extra
        )
        return df[[t.name for t in targets]]

    def identify_non_dominated_configurations(
        self, configurations: pd.DataFrame, /
    ) -> pd.Series:
        """Create a Boolean mask indicating non-dominated target configurations.

        In case of duplicated non-dominated points, all duplicates are marked as
        non-dominated.

        Note:
            Non-dominated configurations can be computed for any objective type, not
            just for :class:`~baybe.objectives.pareto.ParetoObjective`.
            For more details, have a look at the corresponding
            :ref:`user guide section <userguide/objectives:Identifying Non-Dominated
            Configurations>`.

        Args:
            configurations: The target configurations for which the non-dominated subset
                is identified.

        Returns:
            A Boolean series indicating which configurations are non-dominated. The
            series carries the index of ``configurations``, so it can be used
            directly to select rows from the input.
        """
        from botorch.utils.multi_objective.pareto import is_non_dominated

        validate_target_input(configurations, self.targets)
        # NOTE(review): ``transform`` is not defined in this class; presumably it is
        # provided by the concrete subclasses — verify.
        targets = self.transform(configurations)
        non_dominated = is_non_dominated(Y=to_tensor(targets), deduplicate=False)

        # Reuse the input index: a default RangeIndex would not align with inputs
        # carrying a non-default index and would break boolean row selection.
        return pd.Series(
            non_dominated.numpy(),
            index=configurations.index,
            name="is_non_dominated",
        )
def to_objective(x: Target | Objective, /) -> Objective:
    """Convert a target into an objective (with objective passthrough).

    Args:
        x: A target or an objective.

    Returns:
        ``x`` itself if it already is an :class:`Objective`, otherwise the result
        of calling the target's ``to_objective`` method.
    """
    if isinstance(x, Objective):
        return x
    return x.to_objective()
# Collect leftover original slotted classes processed by `attrs.define`
gc.collect()