"""Base classes for all kernels."""from__future__importannotationsfromabcimportABCfromtypingimportTYPE_CHECKING,Anyfromattrsimportdefinefrombaybe.exceptionsimportUnmatchedAttributeErrorfrombaybe.priors.baseimportPriorfrombaybe.serialization.coreimport(converter,get_base_structure_hook,unstructure_base,)frombaybe.serialization.mixinimportSerialMixinfrombaybe.utils.basicimportget_baseclasses,match_attributesifTYPE_CHECKING:importtorchfrombaybe.surrogates.gaussian_process.kernel_factoryimportPlainKernelFactory
@define(frozen=True)
class Kernel(ABC, SerialMixin):
    """Abstract base class for all kernels."""

    def to_factory(self) -> PlainKernelFactory:
        """Wrap the kernel in a :class:`baybe.surrogates.gaussian_process.kernel_factory.PlainKernelFactory`."""  # noqa: E501
        # Local import to avoid a circular dependency at module import time
        # (the factory module itself imports from this module).
        from baybe.surrogates.gaussian_process.kernel_factory import PlainKernelFactory

        return PlainKernelFactory(self)

    def to_gpytorch(
        self,
        *,
        ard_num_dims: int | None = None,
        batch_shape: torch.Size | None = None,
        active_dims: tuple[int, ...] | None = None,
    ):
        """Create the gpytorch representation of the kernel.

        Args:
            ard_num_dims: Number of input dimensions for ARD (one lengthscale per
                dimension). Forwarded to gpytorch only when not ``None``.
            batch_shape: Batch shape of the kernel. Forwarded to gpytorch only
                when not ``None``.
            active_dims: Indices of the input dimensions the kernel operates on.
                Forwarded to gpytorch only when not ``None``.

        Returns:
            The gpytorch kernel object corresponding to this kernel.

        Raises:
            UnmatchedAttributeError: If the kernel has attributes (other than
                ``*_initial_value`` ones) without a counterpart in the
                constructor of the corresponding gpytorch kernel class.
        """
        import gpytorch.kernels

        # Extract keywords with non-default values. This is required since gpytorch
        # makes use of kwargs, i.e. differentiates if certain keywords are explicitly
        # passed or not. For instance, `ard_num_dims = kwargs.get("ard_num_dims", 1)`
        # fails if we explicitly pass `ard_num_dims=None`.
        kw: dict[str, Any] = dict(
            ard_num_dims=ard_num_dims, batch_shape=batch_shape, active_dims=active_dims
        )
        kw = {k: v for k, v in kw.items() if v is not None}

        # Get corresponding gpytorch kernel class and its base classes.
        # The BayBE kernel classes are named after their gpytorch counterparts.
        kernel_cls = getattr(gpytorch.kernels, self.__class__.__name__)
        base_classes = get_baseclasses(kernel_cls, abstract=True)

        # Fetch the necessary gpytorch constructor parameters of the kernel.
        # NOTE: In gpytorch, some attributes (like the kernel lengthscale) are handled
        # via the `gpytorch.kernels.Kernel` base class. Hence, it is not sufficient to
        # just check the fields of the actual class, but also those of its base
        # classes.
        kernel_attrs: dict[str, Any] = {}
        unmatched_attrs: dict[str, Any] = {}
        for cls in [kernel_cls, *base_classes]:
            matched, unmatched = match_attributes(self, cls.__init__, strict=False)
            kernel_attrs.update(matched)
            unmatched_attrs.update(unmatched)

        # Sanity check: all attributes of the BayBE kernel need a corresponding match
        # in the gpytorch kernel (otherwise, the BayBE kernel class is misconfigured).
        # Exception: initial values are not used during construction but are set
        # on the created object (see code at the end of the method).
        # FIX: use the accumulated `unmatched_attrs` instead of the loop variable
        # `unmatched` leaked from the last iteration. The two are equivalent here
        # (an attribute matched by no class is unmatched by the last class too),
        # but relying on loop-variable leakage is fragile and left the accumulator
        # dead.
        missing = set(unmatched_attrs) - set(kernel_attrs)
        if leftover := {m for m in missing if not m.endswith("_initial_value")}:
            raise UnmatchedAttributeError(leftover)

        # Convert specified priors to gpytorch, if provided
        prior_dict = {
            key: value.to_gpytorch()
            for key, value in kernel_attrs.items()
            if isinstance(value, Prior)
        }

        # Convert specified inner kernels to gpytorch, if provided. The explicit
        # keywords (`ard_num_dims` etc.) are propagated to the inner kernels.
        kernel_dict = {
            key: value.to_gpytorch(**kw)
            for key, value in kernel_attrs.items()
            if isinstance(value, Kernel)
        }

        # Create the kernel with all its inner gpytorch objects
        kernel_attrs.update(kernel_dict)
        kernel_attrs.update(prior_dict)
        gpytorch_kernel = kernel_cls(**kernel_attrs, **kw)

        # If the kernel has a lengthscale, set its initial value
        if kernel_cls.has_lengthscale:
            import torch

            from baybe.utils.torch import DTypeFloatTorch

            # We can ignore mypy here and simply assume that the corresponding BayBE
            # kernel class has the necessary lengthscale attribute defined. This is
            # safer than using a `hasattr` check in the above if-condition since for
            # the latter the code would silently fail when forgetting to add the
            # attribute to a new kernel class / misspelling it.
            if (initial_value := self.lengthscale_initial_value) is not None:  # type: ignore[attr-defined]
                gpytorch_kernel.lengthscale = torch.tensor(
                    initial_value, dtype=DTypeFloatTorch
                )

        return gpytorch_kernel
# Marker base class: concrete leaf kernels (i.e. kernels that are not built
# from other kernels) derive from this class.
@define(frozen=True)
class BasicKernel(Kernel, ABC):
    """Abstract base class for all basic kernels."""
# Marker base class: kernels composed of other kernels (e.g. sums/products)
# derive from this class.
@define(frozen=True)
class CompositeKernel(Kernel, ABC):
    """Abstract base class for all composite kernels."""