Source code for baybe.utils.basic

"""Collection of small basic utilities."""

import functools
import inspect
from collections.abc import Callable, Collection, Iterable, Sequence
from dataclasses import dataclass
from typing import Any, TypeVar

from attrs import asdict, has

from baybe.exceptions import UnidentifiedSubclassError, UnmatchedAttributeError

_C = TypeVar("_C", bound=type)
_T = TypeVar("_T")
_U = TypeVar("_U")


@dataclass(repr=False, frozen=True)
class Dummy:
    """Immutable stand-in element for array-like data types.

    Handy, for example, when detecting duplicates in constraints.
    """

    def __repr__(self) -> str:
        """Render the placeholder as a fixed marker string."""
        return "<dummy>"
def get_subclasses(cls: _C, recursive: bool = True, abstract: bool = False) -> list[_C]:
    """Return a duplicate-free list of subclasses for the given class.

    Args:
        cls: The base class to retrieve subclasses for.
        recursive: If ``True``, indirect subclasses (i.e. subclasses of subclasses)
            are included.
        abstract: If ``True``, abstract subclasses are included.

    Returns:
        A list of subclasses for the given class, in depth-first discovery order.
        A class that is reachable via several inheritance paths (diamond
        inheritance) is listed only once, at the position of its first discovery.
    """
    from baybe.utils.boolean import is_abstract

    subclasses: list[_C] = []

    def collect(candidate: _C) -> None:
        """Add the class unless it has already been collected."""
        # List membership (instead of a set) keeps the discovery order and does
        # not require the classes to be hashable.
        if candidate not in subclasses:
            subclasses.append(candidate)

    for subclass in cls.__subclasses__():
        # Collect the direct subclass only if it passes the abstractness filter
        if abstract or not is_abstract(subclass):
            collect(subclass)

        # If requested, descend further. This happens even below filtered-out
        # abstract classes, since their descendants may be concrete.
        if recursive:
            for indirect in get_subclasses(subclass, abstract=abstract):
                collect(indirect)

    return subclasses
def get_baseclasses(
    cls: type,
    recursive: bool = True,
    abstract: bool = False,
) -> list[type]:
    """Return a duplicate-free list of base classes for the given class.

    Args:
        cls: The class to retrieve base classes for.
        recursive: If ``True``, indirect base classes (i.e., base classes of base
            classes) are included.
        abstract: If ``True``, abstract base classes are included.

    Returns:
        A list of base classes for the given class, in depth-first discovery
        order. A class that is reachable via several inheritance paths (e.g.
        ``object`` under multiple inheritance) is listed only once, at the
        position of its first discovery.
    """
    from baybe.utils.boolean import is_abstract

    classes: list[type] = []

    def collect(candidate: type) -> None:
        """Add the class unless it has already been collected."""
        if candidate not in classes:
            classes.append(candidate)

    for baseclass in cls.__bases__:
        # Collect the direct base class only if it passes the abstractness filter
        if abstract or not is_abstract(baseclass):
            collect(baseclass)

        # If requested, ascend further. The results are funneled through the
        # same duplicate check as the direct bases, so shared ancestors are not
        # repeated.
        if recursive:
            for indirect in get_baseclasses(baseclass, abstract=abstract):
                collect(indirect)

    return classes
def hilberts_factory(factory: Callable[..., _T]) -> Iterable[_T]:
    """Yield freshly created products of the factory, without ever stopping.

    The factory is called without arguments, once per requested element.
    """
    while True:
        product = factory()
        yield product
def group_duplicate_values(dictionary: dict[_T, _U]) -> dict[_U, list[_T]]:
    """Collect the keys of a dictionary that share a common value.

    Args:
        dictionary: The dictionary to screen for duplicate values.

    Returns:
        A dictionary that maps each value occurring more than once in the input
        to the list of input keys holding that value. Values that occur only
        once are left out.

    Example:
        >>> group_duplicate_values({"A": 1, "B": 2, "C": 1, "D": 3})
        {1: ['A', 'C']}
    """
    # First pass: bucket all keys by their value, in order of appearance
    keys_by_value: dict[_U, list[_T]] = {}
    for key, value in dictionary.items():
        if value in keys_by_value:
            keys_by_value[value].append(key)
        else:
            keys_by_value[value] = [key]

    # Second pass: keep only the buckets that actually contain duplicates
    return {value: keys for value, keys in keys_by_value.items() if len(keys) >= 2}
def to_tuple(x: Sequence) -> tuple:
    """Turn the given sequence into a tuple.

    The function adds nothing over the built-in ``tuple`` constructor. It only
    exists so that attrs converters can be declared without triggering mypy
    warnings, see:

    * https://github.com/python/mypy/issues/8417
    * https://github.com/python/mypy/issues/8389
    * https://github.com/python/mypy/issues/5313
    """
    # TODO: Remove wrapper once mypy support is there
    return tuple(x)
def match_attributes(
    object: Any,
    callable_: Callable,
    /,
    strict: bool = True,
    ignore: Collection[str] = (),
) -> tuple[dict[str, Any], dict[str, Any]]:
    """Find the attributes of an object that match with a given callable signature.

    Parameters appearing in the callable signature that have no match with the
    given object attributes are ignored.

    Args:
        object: The object whose attributes are to be matched.
        callable_: The callable against whose signature the attributes are to be
            matched.
        strict: If ``True``, an error is raised when the object has attributes
            that are not found in the callable signature (see also ``ignore``).
            If ``False``, these attributes are returned separately.
        ignore: A collection of attributes names that are to be ignored during
            matching.

    Raises:
        ValueError: If applied to a non-attrs object.
        UnmatchedAttributeError: In strict mode, if not all attributes can be
            matched.

    Returns:
        * A dictionary mapping the matched attribute names to their values.
        * A dictionary mapping the unmatched attribute names to their values.

        Both dictionaries list the attributes in the order in which they are
        defined on the object.
    """
    if not has(object.__class__):
        raise ValueError(
            f"'{match_attributes.__name__}' only works with attrs objects."
        )

    # Only the attribute names are needed here, so a shallow conversion suffices.
    # It avoids recursively converting the attribute values, which a default
    # `asdict` call would do. Keeping the names in a list (instead of a set)
    # makes the order of the returned dictionaries deterministic.
    ignored = set(ignore)
    attribute_names = [
        name for name in asdict(object, recurse=False) if name not in ignored
    ]
    parameters = inspect.signature(callable_).parameters

    # Match
    not_in_signature = {name for name in attribute_names if name not in parameters}
    if strict and not_in_signature:
        raise UnmatchedAttributeError(
            f"The following attributes cannot be matched: {not_in_signature}."
        )

    # Collect attributes for both groups
    attrs_in_signature = {
        name: getattr(object, name)
        for name in attribute_names
        if name in parameters
    }
    attrs_not_in_signature = {
        name: getattr(object, name)
        for name in attribute_names
        if name in not_in_signature
    }

    return attrs_in_signature, attrs_not_in_signature
class classproperty:
    """A decorator for defining properties on the class level.

    A class property behaves like a combination of @property and @classmethod.
    Simply stacking these two built-in decorators is deprecated
    (https://docs.python.org/3.11/whatsnew/3.11.html#language-builtins) and
    causes mypy issues, hence this dedicated descriptor.
    """

    def __init__(self, fn: Callable) -> None:
        """Remember the function that computes the property value."""
        self.fn = fn

    def __get__(self, _, cl: type):
        """Compute the property value by calling the function on the owner class."""
        getter = self.fn
        return getter(cl)
def refers_to(cls: type, name_or_abbr: str, /) -> bool:
    """Tell whether the given string is the name or abbreviation of the class.

    The abbreviation is only considered if the class has an ``abbreviation``
    attribute.
    """
    accepted = [cls.__name__]
    if hasattr(cls, "abbreviation"):
        accepted.append(cls.abbreviation)
    return name_or_abbr in accepted
def find_subclass(base: type, name_or_abbr: str, /):
    """Retrieve a specific subclass of a base class via its name or abbreviation.

    Args:
        base: The base class whose subclasses are searched.
        name_or_abbr: The name or abbreviation of the subclass to look for.

    Returns:
        The first subclass of the base class that the given name or abbreviation
        refers to.

    Raises:
        UnidentifiedSubclassError: If none of the subclasses matches.
    """
    # A plain loop (instead of `next` inside `try`/`except StopIteration`) keeps
    # the raised error free of an unrelated chained `StopIteration` traceback.
    for candidate in get_subclasses(base):
        if refers_to(candidate, name_or_abbr):
            return candidate

    raise UnidentifiedSubclassError(
        f"The class name or abbreviation '{name_or_abbr}' does not refer to any "
        f"of the subclasses of '{base.__name__}'."
    )
[docs] def register_hooks( target: Callable, pre_hooks: Sequence[Callable] | None = None, post_hooks: Sequence[Callable] | None = None, ) -> Callable: """Register custom hooks with a given target callable. The provided hooks need to be "compatible" with the target in the sense that their signatures can be aligned: * The hook signature may only contain parameters that also exist in the target callable (<-- basic requirement). * However, parameters that are not needed by the hook can be omitted from its signature. This requires that the parameters of both signatures can be matched via their names. For simplicity, it is thus assumed that the hook has no positional-only arguments. * If an annotation is provided for a hook parameter, it must match its target counterpart (<-- safety mechanism to prevent unintended argument use). An exception is when the target parameter has no annotation, in which case the hook annotation is unrestricted. This is particularly useful when registering hooks with methods, since it offers the possibility to annotate the "self" parameter bound to the method-carrying object, which is typically not annotated in the target callable. Args: target: The callable to which the hooks are to be attached. pre_hooks: Hooks to be executed before calling the target. post_hooks: Hooks to be executed after calling the target. Returns: The wrapped callable with the hooks attached. Raises: TypeError: If any hook has positional-only arguments. TypeError: If any hook expects parameters that are not present in the target. TypeError: If any hook has a non-empty parameter annotation that does not match with the corresponding annotation of the target. 
""" # Defaults pre_hooks = pre_hooks or [] post_hooks = post_hooks or [] target_signature = inspect.signature(target, eval_str=True).parameters # Validate hook signatures for hook in [*pre_hooks, *post_hooks]: hook_signature = inspect.signature(hook, eval_str=True).parameters if any( p.kind is inspect.Parameter.POSITIONAL_ONLY for p in hook_signature.values() ): raise TypeError("The provided hooks cannot have positional-only arguments.") if unrecognized := (set(hook_signature) - set(target_signature)): raise TypeError( f"The parameters expected by the hook '{hook.__name__}' must be a " f"subset of the parameter of the target callable '{target.__name__}'. " f"Unrecognized hook parameters: {unrecognized}." ) for name, hook_param in hook_signature.items(): target_param = target_signature[name] # If target parameter is not annotated, the hook annotation is unrestricted if (t_hint := target_param.annotation) is inspect.Parameter.empty: continue # If target parameter is annotated, the hook annotation must be compatible, # i.e., be identical or empty if ((h_hint := hook_param.annotation) != t_hint) and ( h_hint is not inspect.Parameter.empty ): raise TypeError( f"The type annotation for '{name}' is not consistent between " f"the given hook '{hook.__name__}' and the target callable " f"'{target.__name__}'. Given: {h_hint}. Expected: {t_hint}." 
) def pass_args(hook: Callable, *args, **kwargs) -> None: """Call the hook with its requested subset of arguments.""" hook_signature = inspect.signature(hook, eval_str=True).parameters matched_args = dict(zip(target_signature, args)) matched_kwargs = { p: kwargs.get(p, target_signature[p].default) for p in hook_signature if p not in matched_args } passed_kwargs = {p: (matched_args | matched_kwargs)[p] for p in hook_signature} hook(**passed_kwargs) @functools.wraps(target) def wraps(*args, **kwargs): for hook in pre_hooks: pass_args(hook, *args, **kwargs) result = target(*args, **kwargs) for hook in post_hooks: pass_args(hook, *args, **kwargs) return result return wraps