Source code for baybe.kernels.basic

"""Collection of basic kernels."""

from attrs import define, field
from attrs.converters import optional as optional_c
from attrs.validators import ge, gt, in_, instance_of
from attrs.validators import optional as optional_v

from baybe.kernels.base import BasicKernel
from baybe.priors.base import Prior
from baybe.utils.conversion import fraction_to_float
from baybe.utils.validation import finite_float


@define(frozen=True)
class LinearKernel(BasicKernel):
    """A linear kernel.

    The kernel variance can optionally be equipped with a prior and/or an
    initial value.
    """

    variance_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel variance parameter."""

    variance_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel variance (must be > 0)."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)
        start = self.variance_initial_value
        if start is None:
            return kernel

        # Overwrite the gpytorch default with the user-specified value.
        kernel.variance = torch.tensor(start, dtype=DTypeFloatTorch)
        return kernel
@define(frozen=True)
class MaternKernel(BasicKernel):
    """A Matern kernel whose smoothness is controlled by the parameter ``nu``."""

    nu: float = field(
        default=2.5,
        converter=fraction_to_float,
        validator=in_([0.5, 1.5, 2.5]),
    )
    """The smoothness parameter.

    Restricted to the values 0.5, 1.5 or 2.5; larger values produce smoother
    interpolations.
    """

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""
@define(frozen=True)
class PeriodicKernel(BasicKernel):
    """A periodic kernel.

    Both the lengthscale and the period length can optionally be equipped with
    a prior and/or an initial value.
    """

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""

    period_length_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel period length."""

    period_length_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel period length (must be > 0)."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        # The lengthscale initial value is taken care of by the base class;
        # only the period length needs to be set here.
        kernel = super().to_gpytorch(*args, **kwargs)
        start = self.period_length_initial_value
        if start is None:
            return kernel

        kernel.period_length = torch.tensor(start, dtype=DTypeFloatTorch)
        return kernel
@define(frozen=True)
class PiecewisePolynomialKernel(BasicKernel):
    """A piecewise polynomial kernel."""

    q: int = field(default=2, validator=in_([0, 1, 2, 3]))
    """A smoothness parameter, restricted to the values 0, 1, 2 or 3."""

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""
@define(frozen=True)
class PolynomialKernel(BasicKernel):
    """A polynomial kernel.

    The polynomial power is mandatory; the offset can optionally be equipped
    with a prior and/or an initial value.
    """

    power: int = field(validator=[instance_of(int), ge(0)])
    """The (non-negative integer) power of the polynomial term."""

    offset_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel offset."""

    offset_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel offset (must be > 0)."""

    def to_gpytorch(self, *args, **kwargs):  # noqa: D102
        # See base class.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        kernel = super().to_gpytorch(*args, **kwargs)
        start = self.offset_initial_value
        if start is None:
            return kernel

        # Overwrite the gpytorch default with the user-specified value.
        kernel.offset = torch.tensor(start, dtype=DTypeFloatTorch)
        return kernel
@define(frozen=True)
class RBFKernel(BasicKernel):
    """A radial basis function (RBF) kernel."""

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""
@define(frozen=True)
class RFFKernel(BasicKernel):
    """A random Fourier features (RFF) kernel."""

    num_samples: int = field(validator=[instance_of(int), ge(1)])
    """How many frequencies to draw (a positive integer)."""

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""
@define(frozen=True)
class RQKernel(BasicKernel):
    """A rational quadratic (RQ) kernel."""

    lengthscale_prior: Prior | None = field(
        validator=optional_v(instance_of(Prior)), default=None
    )
    """An optional prior placed on the kernel lengthscale."""

    lengthscale_initial_value: float | None = field(
        validator=optional_v([finite_float, gt(0.0)]),
        converter=optional_c(float),
        default=None,
    )
    """An optional starting value for the kernel lengthscale (must be > 0)."""