Source code for baybe.priors.base

"""Base class for all priors."""

import gc
from abc import ABC

from attrs import define

from baybe.serialization.mixin import SerialMixin
from baybe.utils.basic import match_attributes


@define(frozen=True)
class Prior(ABC, SerialMixin):
    """Abstract base class for all priors."""

    def to_gpytorch(self, *args, **kwargs):
        """Create the gpytorch representation of the prior.

        The target class is looked up in ``gpytorch.priors`` under the same
        name as this object's own class. The first element returned by
        ``match_attributes`` for this object and the target constructor is
        merged into the keyword arguments; on a name clash, the value taken
        from this object wins over the one supplied by the caller.

        Args:
            *args: Positional arguments forwarded to the gpytorch prior.
            **kwargs: Keyword arguments forwarded to the gpytorch prior.

        Returns:
            An instance of the equally named class from ``gpytorch.priors``.
        """
        # Heavy dependencies are imported lazily, only when a conversion is
        # actually requested.
        import gpytorch.priors
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        # TODO: This is only a temporary workaround. A proper solution requires
        # modifying the torch import procedure using the built-in tools of
        # importlib so that the dtype is set whenever torch is lazily loaded.
        torch.set_default_dtype(DTypeFloatTorch)

        # Resolve the gpytorch class that shares this prior's class name.
        target_name = self.__class__.__name__
        gpytorch_prior_cls = getattr(gpytorch.priors, target_name)

        # Class-specific attributes take precedence over caller-supplied kwargs.
        own_fields = match_attributes(self, gpytorch_prior_cls.__init__)[0]
        merged_kwargs = {**kwargs, **own_fields}

        return gpytorch_prior_cls(*args, **merged_kwargs)
# Collect leftover original slotted classes processed by `attrs.define` gc.collect()