"""Functionality for desirability objectives."""

from collections.abc import Callable
from functools import cached_property, partial
from typing import TypeGuard

import cattrs
import numpy as np
import numpy.typing as npt
import pandas as pd
from attrs import define, field
from attrs.validators import deep_iterable, gt, instance_of, min_len

from baybe.objectives.base import Objective
from baybe.objectives.enum import Scalarizer
from baybe.targets.base import Target
from baybe.targets.numerical import NumericalTarget
from baybe.utils.basic import to_tuple
from baybe.utils.numerical import geom_mean
from baybe.utils.validation import finite_float


def _is_all_numerical_targets(
    x: tuple[Target, ...], /
) -> TypeGuard[tuple[NumericalTarget, ...]]:
    """Report whether every given target is a :class:`NumericalTarget`.

    Acts as a type guard: when it returns ``True``, type checkers narrow ``x``
    to a tuple of numerical targets. An empty tuple yields ``True``.

    Args:
        x: The targets to inspect (positional-only).

    Returns:
        ``True`` if all elements are numerical targets, ``False`` as soon as
        one that is not is encountered.
    """
    for candidate in x:
        if not isinstance(candidate, NumericalTarget):
            return False
    return True
def scalarize(
    values: npt.ArrayLike, scalarizer: Scalarizer, weights: npt.ArrayLike
) -> np.ndarray:
    """Collapse every row of a two-dimensional array into a single number.

    Args:
        values: The two-dimensional input whose rows get scalarized.
        scalarizer: Which scalarization mechanism to apply.
        weights: Weights for the columns of ``values``; forwarded unchanged to
            the selected aggregation function.

    Raises:
        ValueError: If ``values`` does not have exactly two dimensions.
        NotImplementedError: If no mechanism is defined for ``scalarizer``.

    Returns:
        np.ndarray: A 1-D array holding one scalarized value per row.
    """
    # Reject anything that is not a matrix before picking the aggregation.
    if np.ndim(values) != 2:
        raise ValueError("The provided array must be two-dimensional.")

    # Identity comparison (not equality) against the enum members, so only the
    # exact Scalarizer members select a mechanism.
    if scalarizer is Scalarizer.GEOM_MEAN:
        aggregate: Callable = geom_mean
    elif scalarizer is Scalarizer.MEAN:
        # Row-wise weighted arithmetic mean.
        aggregate = partial(np.average, axis=1)
    else:
        raise NotImplementedError(
            f"No scalarization mechanism defined for '{scalarizer.name}'."
        )

    return aggregate(values, weights=weights)
@define(frozen=True, slots=False)
class DesirabilityObjective(Objective):
    """An objective scalarizing multiple targets using desirability values.

    Every target column is first passed through its own ``transform``; the
    resulting per-target values are then combined row-wise into a single
    ``"Desirability"`` column using the configured scalarizer together with
    the weights rescaled to sum to one.
    """

    # NOTE: field order matters -- the ``weights`` default reads ``self.targets``,
    # so ``_targets`` must be declared (and therefore initialized) first.
    _targets: tuple[Target, ...] = field(
        converter=to_tuple,
        validator=[
            min_len(2),
            deep_iterable(member_validator=instance_of(Target)),  # type: ignore[type-abstract]
        ],
        alias="targets",
    )
    "The targets considered by the objective."

    weights: tuple[float, ...] = field(
        converter=lambda w: cattrs.structure(w, tuple[float, ...]),
        validator=deep_iterable(member_validator=[finite_float, gt(0.0)]),
    )
    """The weights to balance the different targets.
    By default, all targets are considered equally important."""

    scalarizer: Scalarizer = field(default=Scalarizer.GEOM_MEAN, converter=Scalarizer)
    """The mechanism to scalarize the weighted desirability values of all targets."""

    @weights.default
    def _default_weights(self) -> tuple[float, ...]:
        """Give every target the same unit weight."""
        return (1.0,) * len(self.targets)

    @_targets.validator
    def _validate_targets(self, _, targets) -> None:  # noqa: DOC101, DOC103
        # Only numerical targets are supported at the moment.
        if not _is_all_numerical_targets(targets):
            raise TypeError(
                f"'{self.__class__.__name__}' currently only supports targets "
                f"of type '{NumericalTarget.__name__}'."
            )

        # Target names are used as column keys, so duplicates are not allowed.
        target_names = [t.name for t in targets]
        if len(set(target_names)) != len(target_names):
            raise ValueError("All target names must be unique.")

        # Desirability values require every target transform to be normalized.
        if any(not t._is_transform_normalized for t in targets):
            raise ValueError(
                "All targets must have normalized computational representations to "
                "enable the computation of desirability values. This requires having "
                "appropriate target bounds and transformations in place."
            )

    @weights.validator
    def _validate_weights(self, _, weights) -> None:  # noqa: DOC101, DOC103
        n_weights = len(weights)
        n_targets = len(self.targets)
        if n_weights != n_targets:
            raise ValueError(
                f"If custom weights are specified, there must be one for each target. "
                f"Specified number of targets: {n_targets}. "
                f"Specified number of weights: {n_weights}."
            )

    @property
    def targets(self) -> tuple[Target, ...]:  # noqa: D102
        # See base class.
        return self._targets

    @cached_property
    def _normalized_weights(self) -> np.ndarray:
        """The target weights rescaled so that they sum to one."""
        raw = np.asarray(self.weights)
        return raw / raw.sum()

    def __str__(self) -> str:
        bold, reset = "\033[1m", "\033[0m"

        # One summary row per target, extended by its weight.
        summary_df = pd.DataFrame([t.summary() for t in self.targets])
        summary_df["Weight"] = self.weights

        text = (
            f"{bold}Objective{reset}"
            f"\n{bold}Type: {reset}{self.__class__.__name__}"
            f"\n{bold}Targets {reset}\n{summary_df}"
            f"\n{bold}Scalarizer: {reset}{self.scalarizer.name}"
        )
        # Indent every continuation line by one space.
        return text.replace("\n", "\n ")

    def transform(self, data: pd.DataFrame) -> pd.DataFrame:  # noqa: D102
        # See base class.

        # Start from a copy of the target columns, then replace each one with
        # its individually transformed values.
        per_target = data[[t.name for t in self.targets]].copy()
        for target in self.targets:
            per_target[target.name] = target.transform(data[[target.name]])

        # Collapse the per-target values row-wise into one desirability score.
        desirability = scalarize(
            per_target.values, self.scalarizer, self._normalized_weights
        )

        return pd.DataFrame({"Desirability": desirability}, index=per_target.index)