Source code for baybe.transformations.base

"""Base classes for target transformations."""

from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, TypeVar

from attrs import define
from typing_extensions import override

from baybe.serialization.mixin import SerialMixin
from baybe.utils.basic import is_all_instance
from baybe.utils.dataframe import to_tensor
from baybe.utils.interval import Interval

if TYPE_CHECKING:
    from botorch.acquisition.objective import MCAcquisitionObjective
    from torch import Tensor

_TTransformation = TypeVar("_TTransformation", bound="Transformation")


def _image_equals_codomain(cls: type[_TTransformation], /) -> type[_TTransformation]:
    """Make the image of a transformation identical to its codomain."""
    cls.get_image = cls.get_codomain  # type: ignore[method-assign]
    return cls


[docs] @define(frozen=True) class Transformation(SerialMixin, ABC): """Abstract base class for all transformations.""" @abstractmethod def __call__(self, x: Tensor, /) -> Tensor: """Transform a given input tensor."""
[docs] @abstractmethod def get_codomain(self, interval: Interval | None = None, /) -> Interval: """Get the codomain of a certain interval (assuming transformation continuity). In accordance with the mathematical definition of a function's `codomain <https://en.wikipedia.org/wiki/Codomain>`_, we define the codomain of a given :class:`~baybe.utils.interval.Interval` under a certain (assumed continuous) :class:`~Transformation` to be an :class:`~baybe.utils.interval.Interval` guaranteed to contain all possible outcomes when the :class:`~Transformation` is applied to all points in the input :class:`~baybe.utils.interval.Interval`. In cases where the image cannot exactly be computed, it is often still possible to compute a codomain. The codomain always contains the image, but might be larger. """
[docs] def get_image(self, interval: Interval | None = None, /) -> Interval: """Get the image of a certain interval (assuming transformation continuity). In accordance with the mathematical definition of a function's `image <https://en.wikipedia.org/wiki/Image_(mathematics)>`_, we define the image of a given :class:`~baybe.utils.interval.Interval` under a certain (assumed continuous) :class:`~Transformation` to be the smallest :class:`~baybe.utils.interval.Interval` containing all possible outcomes when the :class:`~Transformation` is applied to all points in the input :class:`~baybe.utils.interval.Interval`. """ # By default, it is assumed that the exact image of an interval cannot be # computed but only the codomain is available (see :meth:`get_codomain`). # Transformations that can provide the exact image should override this method. raise NotImplementedError( f"The exact image of the interval cannot be computed. " f"If sufficient, use '{self.get_codomain.__name__}' instead." )
[docs] def to_botorch_objective(self) -> MCAcquisitionObjective: """Convert to BoTorch objective.""" from botorch.acquisition.objective import GenericMCObjective return GenericMCObjective(lambda samples, X: self(samples))
[docs] def chain(self, transformation: Transformation, /) -> Transformation: """Chain another transformation with the existing one.""" return self | transformation
[docs] def negate(self) -> Transformation: """Negate the output of the transformation.""" from baybe.transformations.basic import AffineTransformation return self | AffineTransformation(factor=-1)
[docs] def clamp( self, min: float = float("-inf"), max: float = float("inf") ) -> Transformation: """Clamp the output of the transformation.""" if min == float("-inf") and max == float("inf"): raise ValueError( "A clamping transformation requires at least one finite boundary value." ) from baybe.transformations.basic import ClampingTransformation return self | ClampingTransformation(min, max)
[docs] def abs(self) -> Transformation: """Take the absolute value of the output of the transformation.""" from baybe.transformations.basic import AbsoluteTransformation return self | AbsoluteTransformation()
def __neg__(self) -> Transformation: return self.negate() def __add__(self, other: Any) -> Transformation: """Add a constant or the output of another transformation.""" if isinstance(other, Transformation): from baybe.transformations import AdditiveTransformation return AdditiveTransformation([self, other]) if isinstance(other, (int, float)): from baybe.transformations import AffineTransformation return self | AffineTransformation(shift=other) return NotImplemented def __sub__(self, other: Any) -> Transformation: """Subtract a constant from the output of the transformation.""" if isinstance(other, Transformation): from baybe.transformations import AdditiveTransformation return AdditiveTransformation([self, -other]) if isinstance(other, (int, float)): from baybe.transformations import AffineTransformation return self | AffineTransformation(shift=-other) return NotImplemented def __mul__(self, other: Any) -> Transformation: """Multiply with a constant or the output of another transformation.""" if isinstance(other, Transformation): from baybe.transformations import MultiplicativeTransformation return MultiplicativeTransformation([self, other]) if isinstance(other, (int, float)): from baybe.transformations import AffineTransformation return self | AffineTransformation(factor=other) return NotImplemented def __truediv__(self, other: Any) -> Transformation: """Divide the output of the transformation by a constant.""" if isinstance(other, (int, float)): from baybe.transformations import AffineTransformation if other == 0: raise ValueError("Division by zero is not allowed.") return self | AffineTransformation(factor=1 / other) return NotImplemented def __or__(self, other: Any) -> Transformation: """Chain the transformation with another one. Inspired by the Unix "pipe".""" from baybe.transformations import ( AffineTransformation, ChainedTransformation, IdentityTransformation, combine_affine_transformations, ) if isinstance(other, IdentityTransformation): return self if is_all_instance(t := [self, other], AffineTransformation): return combine_affine_transformations(*t) if isinstance(other, Transformation): return ChainedTransformation([self, other]) if callable(other): from baybe.transformations.basic import CustomTransformation return self | CustomTransformation(other) return NotImplemented @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): """Chain the transformation with a given torch callable.""" if not ( len(args) == 1 and isinstance(args[0], Transformation) and kwargs is None ): raise ValueError( "Composing transformations with torch operations is only supported " "if the transformation enters as the only (positional) argument." ) from baybe.transformations.basic import CustomTransformation return args[0] | CustomTransformation(func)
[docs] @_image_equals_codomain @define(frozen=True) class MonotonicTransformation(Transformation, ABC): """Abstract base class for monotonic transformations."""
[docs] @override def get_codomain(self, interval: Interval | None = None, /) -> Interval: interval = Interval.create(interval) return Interval( *sorted( [ self(to_tensor(interval.lower)).item(), self(to_tensor(interval.upper)).item(), ] ) )