# Source code for baybe.transformations.utils

"""Transformation utilities."""

from __future__ import annotations

from collections.abc import Callable, Iterable
from typing import TYPE_CHECKING

import numpy as np

from baybe.transformations.base import Transformation

# `Tensor` is only used to define the type alias below, so torch is imported solely
# during static type checking and not when this module is imported at runtime.
if TYPE_CHECKING:
    from torch import Tensor

    # NOTE: This alias exists only while type checking. It can still be named in
    # annotations because `from __future__ import annotations` defers their
    # evaluation.
    TensorCallable = Callable[[Tensor], Tensor]
    """Type alias for a torch-based function mapping from reals to reals."""


def convert_transformation(x: Transformation | TensorCallable, /) -> Transformation:
    """Return the input as a transformation, wrapping plain callables if needed.

    Args:
        x: Either a transformation object or a torch callable.

    Returns:
        The input itself if it already is a transformation, otherwise a custom
        transformation wrapping the given callable.
    """
    # NOTE(review): imported locally, presumably to avoid a circular import
    # between the transformation modules -- confirm before moving to module level.
    from baybe.transformations.basic import CustomTransformation

    # Transformations pass through untouched
    if isinstance(x, Transformation):
        return x

    # Everything else gets wrapped
    return CustomTransformation(x)
def combine_affine_transformations(t1, t2, /):
    """Merge two sequentially applied affine transformations into a single one.

    Assuming each transformation maps ``x`` to ``factor * x + shift``, the result
    corresponds to applying ``t1`` first and ``t2`` second:
    ``t2.factor * (t1.factor * x + t1.shift) + t2.shift``.

    Args:
        t1: The affine transformation applied first.
        t2: The affine transformation applied second.

    Returns:
        A single affine transformation carrying the combined factor and shift.

    Raises:
        OverflowError: If the combined factor or shift is not finite.
    """
    from baybe.transformations.basic import AffineTransformation

    combined_factor = t2.factor * t1.factor
    combined_shift = t2.factor * t1.shift + t2.shift

    # Both parameters must be finite for the result to be usable
    parameters_are_finite = np.isfinite([combined_factor, combined_shift]).all()
    if parameters_are_finite:
        return AffineTransformation(factor=combined_factor, shift=combined_shift)

    raise OverflowError(
        "The combined affine transformation produces infinite values."
    )
def _flatten_transformations(
    transformations: Iterable[Transformation], /
) -> Iterable[Transformation]:
    """Lazily yield all non-chained transformations, unpacking nested chains.

    Args:
        transformations: An iterable of transformations, possibly containing
            chained transformations (which may themselves contain chains).

    Yields:
        The non-chained transformations in their original order, with every
        chained transformation replaced by its (recursively flattened) members.
    """
    from baybe.transformations.composite import ChainedTransformation

    for item in transformations:
        # Plain transformations are passed on as they are
        if not isinstance(item, ChainedTransformation):
            yield item
            continue

        # Chains are unpacked in place, descending into nested chains
        yield from _flatten_transformations(item.transformations)
def compress_transformations(
    transformations: Iterable[Transformation], /
) -> tuple[Transformation, ...]:
    """Compress any iterable of transformations by removing redundancies.

    Flattens nested chained transformations, drops identity transformations and
    combines subsequent affine transformations. If combining two affine
    transformations yields a transformation that compares equal to the identity
    (e.g. scaling by 2 followed by scaling by 0.5), the combined transformation
    is dropped as well.

    Args:
        transformations: An iterable of transformations.

    Returns:
        The minimum sequence of transformations that is equivalent to the input.
        If nothing remains after compression, a tuple containing a single
        identity transformation is returned.

    Raises:
        OverflowError: If combining two subsequent affine transformations
            produces non-finite parameters.
    """
    from baybe.transformations.basic import AffineTransformation, IdentityTransformation

    identity = IdentityTransformation()
    aggregated: list[Transformation] = []

    for t in _flatten_transformations(transformations):
        # Drop identity transformations (and such that are equivalent to it)
        if t == identity:
            continue

        previous = aggregated[-1] if aggregated else None

        # Combine subsequent affine transformations
        if isinstance(previous, AffineTransformation) and isinstance(
            t, AffineTransformation
        ):
            merged = combine_affine_transformations(previous, t)
            if merged == identity:
                # The two transformations cancel each other out. Dropping the pair
                # is safe: adjacent affine transformations are always merged, so
                # the element now at the end (if any) is not affine and no further
                # merge opportunity is lost.
                aggregated.pop()
            else:
                aggregated[-1] = merged

        # Keep other transformations
        else:
            aggregated.append(t)

    # Handle edge case when everything was dropped (only identities or affine
    # transformations that cancel out)
    if not aggregated:
        return (IdentityTransformation(),)

    return tuple(aggregated)