"""Utilities for handling intervals."""fromcollections.abcimportIterablefromfunctoolsimportsingledispatchmethodfromtypingimportTYPE_CHECKING,Anyimportnumpyasnpfromattrsimportdefine,fieldfrombaybe.serializationimportSerialMixin,converterfrombaybe.utils.numericalimportDTypeFloatNumpyifTYPE_CHECKING:fromtorchimportTensor# TODO[typing]: Add return type hints to classmethod constructors once ForwardRefs# are supported: https://bugs.python.org/issue41987
class InfiniteIntervalError(Exception):
    """Raised when an interval that is required to be finite turns out to be infinite."""
@define
class Interval(SerialMixin):
    """Intervals on the real number line.

    Both ends are inclusive (see :meth:`contains`). A missing (``None``) end is
    converted to the corresponding infinity, so ``Interval(None, None)`` spans
    the entire real line. Equal bounds are allowed and yield a degenerate
    interval containing a single number.
    """

    lower: float = field(converter=lambda x: float(x) if x is not None else -np.inf)
    """The lower end of the interval."""

    upper: float = field(converter=lambda x: float(x) if x is not None else np.inf)
    """The upper end of the interval."""

    @upper.validator
    def _validate_order(self, _: Any, upper: float) -> None:  # noqa: DOC101, DOC103
        """Validate the order of the interval bounds.

        Equal bounds are accepted (degenerate interval).

        Raises:
            ValueError: If the upper end is smaller than the lower end.
        """
        if upper < self.lower:
            raise ValueError(
                f"The upper interval bound (provided value: {upper}) cannot be smaller "
                f"than the lower bound (provided value: {self.lower})."
            )

    @property
    def is_degenerate(self) -> bool:
        """Check if the interval is degenerate (i.e., contains only a single number)."""
        return self.lower == self.upper

    @property
    def is_bounded(self) -> bool:
        """Check if the interval is bounded."""
        return self.is_left_bounded and self.is_right_bounded

    @property
    def is_left_bounded(self) -> bool:
        """Check if the interval is left-bounded (i.e., has a finite lower end)."""
        # ``np.isfinite`` returns ``numpy.bool_``; convert so the result is a
        # genuine ``bool`` as annotated (and identity checks like ``is True`` work).
        return bool(np.isfinite(self.lower))

    @property
    def is_right_bounded(self) -> bool:
        """Check if the interval is right-bounded (i.e., has a finite upper end)."""
        # See ``is_left_bounded`` for why the result is wrapped in ``bool``.
        return bool(np.isfinite(self.upper))

    @property
    def is_half_bounded(self) -> bool:
        """Check if the interval is half-bounded (exactly one end is finite)."""
        return self.is_left_bounded ^ self.is_right_bounded

    @property
    def is_fully_unbounded(self) -> bool:
        """Check if the interval represents the entire real number line."""
        return not (self.is_left_bounded or self.is_right_bounded)

    @property
    def center(self) -> float | None:
        """The center of the interval, or ``None`` if the interval is unbounded."""
        if not self.is_bounded:
            return None
        return (self.lower + self.upper) / 2

    @singledispatchmethod
    @classmethod
    def create(cls, value: Any):
        """Create an interval from various input types.

        Dispatch happens on the runtime type of ``value``. The overloads
        registered in this class handle ``None`` (the entire real line) and
        iterables providing the two bounds.

        Raises:
            NotImplementedError: If no overload is registered for the type of
                ``value``.
        """
        raise NotImplementedError(f"Unsupported argument type: {type(value)}")

    @create.register
    @classmethod
    def _(cls, _: None):
        """Overloaded implementation for creating an interval spanning the entire real line."""
        return cls(-np.inf, np.inf)

    @create.register
    @classmethod
    def _(cls, bounds: Iterable):
        """Overloaded implementation for creating an interval from an iterable of bounds.

        The iterable is unpacked positionally into ``(lower, upper)``.
        """
        return cls(*bounds)

    def to_tuple(self) -> tuple[float, float]:
        """Transform the interval to a ``(lower, upper)`` tuple."""
        return self.lower, self.upper

    def to_ndarray(self) -> np.ndarray:
        """Transform the interval to a :class:`numpy.ndarray` of ``[lower, upper]``."""
        return np.array([self.lower, self.upper], dtype=DTypeFloatNumpy)

    def to_tensor(self) -> "Tensor":
        """Transform the interval to a :class:`torch.Tensor` of ``[lower, upper]``."""
        # Local imports keep torch an optional, lazily loaded dependency.
        import torch

        from baybe.utils.torch import DTypeFloatTorch

        return torch.tensor([self.lower, self.upper], dtype=DTypeFloatTorch)

    def contains(self, number: float) -> bool:
        """Check whether the interval contains a given number.

        Both ends are inclusive.

        Args:
            number: The number that should be checked.

        Returns:
            Whether or not the interval contains the number.
        """
        return self.lower <= number <= self.upper
def convert_bounds(bounds: None | Iterable | Interval) -> Interval:
    """Normalize an interval specification into an :class:`Interval`.

    Args:
        bounds: Either an existing interval, which is passed through unchanged,
            or any input accepted by :meth:`Interval.create`.

    Returns:
        The corresponding interval.
    """
    return bounds if isinstance(bounds, Interval) else Interval.create(bounds)
def use_fallback_constructor_hook(value: Any, cls: type[Interval]) -> Interval:
    """Use the singledispatch mechanism as fallback to parse arbitrary input.

    Dictionaries are structured attribute-by-attribute into ``cls`` via the
    serialization converter; every other value is handed to
    :meth:`Interval.create`.
    """
    if not isinstance(value, dict):
        return Interval.create(value)
    return converter.structure_attrs_fromdict(value, cls)