Source code for baybe.utils.metadata

"""Generic metadata system for BayBE objects."""

from __future__ import annotations

from typing import Any, TypeVar

import cattrs
from attrs import AttrsInstance, define, field, fields
from attrs.validators import deep_mapping, instance_of
from attrs.validators import optional as optional_v
from typing_extensions import override

from baybe.serialization import SerialMixin, converter
from baybe.serialization.core import _TYPE_FIELD
from baybe.utils.basic import classproperty

_TMetaData = TypeVar("_TMetaData", bound="Metadata")


@define(frozen=True)
class Metadata(SerialMixin):
    """Basic descriptive information that can be attached to BayBE objects.

    In addition to the explicitly declared fields, arbitrary user-defined
    entries can be stored in :attr:`misc`, provided that their keys are strings
    and do not clash with an explicit field name or the reserved ``_TYPE_FIELD``
    key.
    """

    description: str | None = field(
        default=None, validator=optional_v(instance_of(str))
    )
    """A description of the object."""

    misc: dict[str, Any] = field(
        factory=dict,
        validator=deep_mapping(
            mapping_validator=instance_of(dict),
            key_validator=instance_of(str),
            # FIXME: https://github.com/python-attrs/attrs/issues/1246
            value_validator=lambda *x: None,
        ),
        kw_only=True,
    )
    """Additional user-defined metadata."""

    @misc.validator
    def _validate_misc(self, _, value: dict[str, Any]) -> None:
        """Reject miscellaneous keys that collide with reserved names.

        Raises:
            ValueError: If a key of ``value`` equals the name of an explicit
                metadata field or the reserved ``_TYPE_FIELD`` key.
        """
        reserved = self._explicit_fields | {_TYPE_FIELD}
        inv = set(value).intersection(reserved)
        if not inv:
            return
        raise ValueError(
            f"Miscellaneous metadata cannot contain the following fields: {inv}. "
            f"Use the corresponding attributes instead."
        )

    @classproperty
    def _explicit_fields(cls: type[AttrsInstance]) -> set[str]:
        """The names of all declared fields except the ``misc`` field."""  # noqa: D401
        attributes = fields(cls)
        misc_name = attributes.misc.name
        return {attribute.name for attribute in attributes if attribute.name != misc_name}

    @property
    def is_empty(self) -> bool:
        """Tell whether neither a description nor miscellaneous entries are set."""
        if self.description is not None:
            return False
        return not self.misc
@define(frozen=True)
class MeasurableMetadata(Metadata):
    """Metadata for BayBE :class:`Parameter` objects, extended by a unit field."""

    unit: str | None = field(default=None, validator=optional_v(instance_of(str)))
    """The unit of measurement for the parameter."""

    @override
    @property
    def is_empty(self) -> bool:
        """Tell whether neither the inherited fields nor the unit carry content."""
        # The inherited check runs first, mirroring the base-class contract
        if not super().is_empty:
            return False
        return self.unit is None
def to_metadata(
    value: dict[str, Any] | _TMetaData | None, cls: type[_TMetaData], /
) -> _TMetaData:
    """Coerce the given input into an instance of the requested metadata class.

    Args:
        value: The metadata input. ``None`` yields a default-constructed
            instance, an existing instance of ``cls`` is passed through
            unchanged, and a dictionary is structured via the converter.
        cls: The specific :class:`Metadata` subclass to convert to.

    Returns:
        The metadata instance of the requested :class:`Metadata` subclass.

    Raises:
        TypeError: If the input is neither ``None``, a dictionary, nor an
            instance of the specified :class:`Metadata` type.
    """
    # No input at all: fall back to the all-defaults instance
    if value is None:
        return cls()

    # Already of the requested type: nothing to convert
    if isinstance(value, cls):
        return value

    # Dictionaries are delegated to the converter
    if isinstance(value, dict):
        return converter.structure(value, cls)

    raise TypeError(
        f"The input must be a dictionary or a '{cls.__name__}' instance. "
        f"Got: {type(value)}"
    )
ConvertibleToMeasurableMetadata = MeasurableMetadata | dict[str, Any] | None
"""A type alias for objects that can be converted to :class:`MeasurableMetadata`."""


@converter.register_structure_hook
def _separate_metadata_fields(dct: dict[str, Any], cls: type[Metadata]) -> Metadata:
    """Separate known fields from miscellaneous metadata.

    Args:
        dct: The flat dictionary representation of the metadata.
        cls: The :class:`Metadata` (sub)class to instantiate.

    Returns:
        An instance of ``cls`` whose explicit fields are populated from the
        matching keys of ``dct`` and whose ``misc`` holds all remaining entries.
    """
    # Operate on a shallow copy so that the caller's dictionary stays untouched
    remaining = dct.copy()

    # The reserved type key is never treated as user metadata
    remaining.pop(_TYPE_FIELD, None)

    # Forward only the explicit fields that are actually present. Absent fields
    # are left to the defaults declared on the class instead of being forced to
    # ``None``, which would override any non-``None`` default of a subclass.
    explicit = {
        name: remaining.pop(name)
        for name in cls._explicit_fields
        if name in remaining
    }
    return cls(**explicit, misc=remaining)


@converter.register_unstructure_hook
def _flatten_misc_metadata(metadata: Metadata) -> dict[str, Any]:
    """Flatten the metadata for serialization.

    Args:
        metadata: The metadata object to unstructure.

    Returns:
        The unstructured dictionary in which the entries of ``misc`` have been
        lifted to the top level, next to the explicit fields.
    """
    unstructure = cattrs.gen.make_dict_unstructure_fn(type(metadata), converter)
    flat = unstructure(metadata)

    # Remove the nested mapping first and merge afterwards, so that the result
    # does not hinge on operand evaluation order within a single expression
    misc = flat.pop(fields(Metadata).misc.name)
    return flat | misc