"""A collection of common prior distributions."""fromtypingimportAnyfromattrsimportdefine,fieldfromattrs.validatorsimportgtfrombaybe.priors.baseimportPriorfrombaybe.utils.validationimportfinite_float
@define(frozen=True)
class GammaPrior(Prior):
    """Gamma prior, specified through its concentration and rate.

    Both parameters are coerced to ``float`` and must be finite and strictly
    positive.
    """

    concentration: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Concentration parameter (finite, strictly positive)."""

    rate: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Rate parameter (finite, strictly positive)."""
@define(frozen=True)
class HalfCauchyPrior(Prior):
    """Half-Cauchy prior, specified through a single scale parameter.

    The scale is coerced to ``float`` and must be finite and strictly positive.
    """

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Scale parameter (finite, strictly positive)."""
@define(frozen=True)
class NormalPrior(Prior):
    """Normal prior, specified through a location and a scale.

    Both values are coerced to ``float``. The location may be any finite
    number, while the scale must additionally be strictly positive.
    """

    loc: float = field(validator=finite_float, converter=float)
    """Location parameter ``mu`` (any finite value)."""

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Scale parameter ``sigma`` (finite, strictly positive)."""
@define(frozen=True)
class HalfNormalPrior(Prior):
    """Half-Normal prior, specified through a single scale parameter.

    The scale is coerced to ``float`` and must be finite and strictly positive.
    """

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Scale parameter ``sigma`` (finite, strictly positive)."""
@define(frozen=True)
class LogNormalPrior(Prior):
    """Log-Normal prior, specified through a location and a scale.

    Both values are coerced to ``float``. The location may be any finite
    number, while the scale must additionally be strictly positive.
    """

    loc: float = field(validator=finite_float, converter=float)
    """Location parameter ``mu`` (any finite value)."""

    scale: float = field(
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Scale parameter ``sigma`` (finite, strictly positive)."""
@define(frozen=True)
class SmoothedBoxPrior(Prior):
    """Smoothed-Box prior, specified through the bounds ``a``/``b`` and ``sigma``.

    All values are coerced to ``float``. Both bounds must be finite, and ``b``
    must be strictly greater than ``a``. ``sigma`` must be finite and strictly
    positive; it defaults to ``0.01``.
    """

    a: float = field(validator=finite_float, converter=float)
    """Lower (left) bound."""

    b: float = field(validator=finite_float, converter=float)
    """Upper (right) bound; must be strictly greater than ``a``."""

    sigma: float = field(
        default=0.01,
        validator=[finite_float, gt(0.0)],
        converter=float,
    )
    """Scale parameter ``sigma`` (finite, strictly positive)."""

    @b.validator
    def _validate_order(self, _: Any, b: float) -> None:  # noqa: DOC101, DOC103
        """Ensure the upper bound lies strictly above the lower bound.

        attrs runs this in addition to the ``finite_float`` validator declared
        on ``b``, after ``a`` has already been assigned.

        Raises:
            ValueError: If ``b`` is less than or equal to ``a``.
        """
        lower = self.a
        if b <= lower:
            raise ValueError(
                f"For {self.__class__.__name__}, the upper bound `b` (provided: {b}) "
                f"must be larger than the lower bound `a` (provided: {lower})."
            )