Source code for baybe.priors.basic
"""A collection of common prior distributions."""
from __future__ import annotations
import gc
from typing import Any
from attrs import define, field
from attrs.validators import gt
from typing_extensions import override
from baybe.priors.base import Prior
from baybe.utils.validation import finite_float
[docs]
@define(frozen=True)
class GammaPrior(Prior):
    """Gamma prior described by a concentration and a rate parameter.

    Both values are converted to ``float``, validated with ``finite_float``
    and must be strictly greater than zero.
    """

    concentration: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The concentration parameter."""

    rate: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The rate parameter."""
[docs]
@define(frozen=True)
class HalfCauchyPrior(Prior):
    """Half-Cauchy prior described by a single scale parameter.

    The scale is converted to ``float``, validated with ``finite_float`` and
    must be strictly greater than zero.
    """

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The scale parameter."""
[docs]
@define(frozen=True)
class NormalPrior(Prior):
    """Normal prior described by a location (mu) and a scale (sigma).

    Both values are converted to ``float`` and validated with ``finite_float``;
    the scale must additionally be strictly greater than zero.
    """

    loc: float = field(validator=finite_float, converter=float)
    """The location parameter (mu)."""

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The scale parameter (sigma)."""
[docs]
@define(frozen=True)
class HalfNormalPrior(Prior):
    """Half-Normal prior described by a single scale parameter (sigma).

    The scale is converted to ``float``, validated with ``finite_float`` and
    must be strictly greater than zero.
    """

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The scale parameter (sigma)."""
[docs]
@define(frozen=True)
class LogNormalPrior(Prior):
    """Log-Normal prior described by a location (mu) and a scale (sigma).

    Both values are converted to ``float`` and validated with ``finite_float``;
    the scale must additionally be strictly greater than zero.
    """

    loc: float = field(validator=finite_float, converter=float)
    """The location parameter (mu)."""

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The scale parameter (sigma)."""
[docs]
@define(frozen=True)
class SmoothedBoxPrior(Prior):
    """Smoothed-Box prior described by the bounds ``a`` and ``b`` and ``sigma``.

    All values are converted to ``float`` and validated with ``finite_float``.
    ``sigma`` must be strictly greater than zero, and ``b`` must be strictly
    greater than ``a``.
    """

    a: float = field(validator=finite_float, converter=float)
    """The lower (left) bound."""

    b: float = field(validator=finite_float, converter=float)
    """The upper (right) bound."""

    sigma: float = field(
        default=0.01,
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """The smoothing scale (defaults to 0.01)."""

    @b.validator
    def _validate_order(self, _: Any, b: float) -> None:  # noqa: DOC101, DOC103
        """Check that the upper bound lies strictly above the lower bound.

        Raises:
            ValueError: If ``b`` is less than or equal to ``a``.
        """
        lower = self.a
        # Comparison kept as ``<=`` so non-comparable values (e.g. NaN) are
        # treated exactly as before.
        if b <= lower:
            cls_name = self.__class__.__name__
            raise ValueError(
                f"For {cls_name}, the upper bound `b` (provided: {b}) "
                f"must be larger than the lower bound `a` (provided: {lower})."
            )
[docs]
@define(frozen=True)
class BetaPrior(Prior):
    """A Beta prior parameterized by alpha and beta.

    Both concentration parameters are converted to ``float``, validated with
    ``finite_float`` (like every other strictly-positive parameter in this
    module) and must be strictly greater than zero.
    """

    # ``finite_float`` added for consistency with the sibling priors: ``gt(0.0)``
    # alone lets ``float("inf")`` through.
    alpha: float = field(converter=float, validator=[finite_float, gt(0.0)])
    """Alpha concentration parameter. Controls mass accumulated toward zero."""

    beta: float = field(converter=float, validator=[finite_float, gt(0.0)])
    """Beta concentration parameter. Controls mass accumulated toward one."""

    @override
    def to_gpytorch(self, *args, **kwargs):
        """Signal that this prior cannot be converted to a GPyTorch prior.

        Raises:
            NotImplementedError: Always, since no GPyTorch analog exists.
        """
        raise NotImplementedError(
            f"'{self.__class__.__name__}' does not have a gpytorch analog."
        )
# `attrs.define` builds a new slotted class for each decorated class and leaves
# the original class objects behind; run a full garbage collection (generation
# 2, the same as the no-argument default) so those leftovers are reclaimed at
# import time. Presumably this keeps them out of class introspection such as
# subclass discovery — confirm against the `Prior` base class.
gc.collect(2)