"""Base class for all priors."""fromabcimportABCfromattrsimportdefinefrombaybe.serialization.coreimport(converter,get_base_structure_hook,unstructure_base,)frombaybe.serialization.mixinimportSerialMixinfrombaybe.utils.basicimportmatch_attributes
@define(frozen=True)
class Prior(ABC, SerialMixin):
    """Abstract base class for all priors.

    Concrete subclasses are looked up by class name in ``gpytorch.priors`` when
    converted via :meth:`to_gpytorch`. A subclass's name therefore has to match
    the name of a prior class in that module.
    """

    def to_gpytorch(self, *args, **kwargs):
        """Create the gpytorch representation of the prior.

        The gpytorch prior class is resolved via this object's class name in
        ``gpytorch.priors``. Attributes of this object that are matched against
        that class's ``__init__`` are merged into the keyword arguments.

        Args:
            *args: Positional arguments forwarded unchanged to the gpytorch
                prior constructor.
            **kwargs: Extra keyword arguments for the gpytorch prior
                constructor. On name collisions, the matched attribute values
                of this object take precedence over these.

        Returns:
            An instance of the equally named class from ``gpytorch.priors``.

        Raises:
            AttributeError: If ``gpytorch.priors`` defines no attribute with
                this object's class name.
        """
        # Imported lazily so that torch/gpytorch are only loaded on demand.
        import gpytorch.priors
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        # TODO: This is only a temporary workaround. A proper solution requires
        # modifying the torch import procedure using the built-in tools of importlib
        # so that the dtype is set whenever torch is lazily loaded.
        # NOTE: this sets torch's process-wide default dtype as a side effect.
        torch.set_default_dtype(DTypeFloatTorch)

        target_name = self.__class__.__name__
        gpytorch_prior = getattr(gpytorch.priors, target_name)

        # Only the first element of the helper's result is used; presumably the
        # attributes of `self` that match the constructor signature — verify
        # against `match_attributes`.
        matched_attrs = match_attributes(self, gpytorch_prior.__init__)[0]

        # Class-specific attributes override any equally named caller kwargs.
        kwargs.update(matched_attrs)
        return gpytorch_prior(*args, **kwargs)