Source code for baybe.scaler
"""Functionality for data scaling."""
from __future__ import annotations
from abc import ABC, abstractmethod
from collections.abc import Callable
from typing import TYPE_CHECKING
import pandas as pd
from baybe.utils.dataframe import to_tensor
if TYPE_CHECKING:
from torch import Tensor
_ScaleFun = Callable[[Tensor], Tensor]


class Scaler(ABC):
    """Abstract base class for all scalers.

    Args:
        searchspace: The search space that should be scaled.
    """

    type: str
    """Class variable encoding the type of the scaler."""

    SUBCLASSES: dict[str, type[Scaler]] = {}
    """Class variable for all subclasses"""

    def __init__(self, searchspace: pd.DataFrame):
        self.searchspace = searchspace
        self.fitted = False

        # Forward/inverse scaling callables for inputs and targets,
        # assigned when the scaler is fitted
        self.scale_x: _ScaleFun
        self.scale_y: _ScaleFun
        self.unscale_x: _ScaleFun
        self.unscale_y: _ScaleFun

        # Inverse scaling callables for posterior means and variances
        self.unscale_m: _ScaleFun
        self.unscale_s: _ScaleFun

    @abstractmethod
    def fit_transform(self, x: Tensor, y: Tensor) -> tuple[Tensor, Tensor]:
        """Fit the scaler using the given training data and transform the data.

        Args:
            x: The x-data that should be used.
            y: The y-data that should be used.

        Returns:
            The transformed data.
        """

    def transform(self, x: Tensor) -> Tensor:
        """Scale a given input.

        Args:
            x: The given input.

        Returns:
            The scaled input.

        Raises:
            RuntimeError: If the scaler is not fitted first.
        """
        if not self.fitted:
            raise RuntimeError("Scaler object must be fitted first.")
        return self.scale_x(x)

    def untransform(self, mean: Tensor, variance: Tensor) -> tuple[Tensor, Tensor]:
        """Transform mean values and variances back to the original domain.

        Args:
            mean: The given mean values.
            variance: The given variances.

        Returns:
            The "un-transformed" means and variances.

        Raises:
            RuntimeError: If the scaler object is not fitted first.
        """
        if not self.fitted:
            raise RuntimeError("Scaler object must be fitted first.")
        return self.unscale_m(mean), self.unscale_s(variance)

    @classmethod
    def __init_subclass__(cls, **kwargs):
        """Register new subclasses dynamically."""
        super().__init_subclass__(**kwargs)
        cls.SUBCLASSES[cls.type] = cls
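

# Illustrative sketch (not part of the original module): a minimal, hypothetical
# subclass showing the contract imposed by the base class. Merely defining it
# registers the class in Scaler.SUBCLASSES under its `type` key via
# __init_subclass__ above; the name "IdentityScaler" and the type string
# "IDENTITY" are invented for this example.
class IdentityScaler(Scaler):
    """A scaler that leaves inputs and targets unchanged."""

    type = "IDENTITY"

    def fit_transform(  # noqa: D102
        self, x: Tensor, y: Tensor
    ) -> tuple[Tensor, Tensor]:
        # See base class.
        # All forward and inverse scalings are the identity function
        self.scale_x = self.scale_y = lambda t: t
        self.unscale_x = self.unscale_y = lambda t: t
        self.unscale_m = self.unscale_s = lambda t: t

        # Flag that the scaler has been fitted
        self.fitted = True
        return self.scale_x(x), self.scale_y(y)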


class DefaultScaler(Scaler):
    """A scaler that normalizes inputs to the unit cube and standardizes targets."""

    type = "DEFAULT"
    # See base class.

    def fit_transform(  # noqa: D102
        self, x: Tensor, y: Tensor
    ) -> tuple[Tensor, Tensor]:
        # See base class.
        import torch

        # Get the searchspace boundaries
        searchspace = to_tensor(self.searchspace)
        bounds = torch.vstack(
            [torch.min(searchspace, dim=0)[0], torch.max(searchspace, dim=0)[0]]
        )

        # Compute the mean and standard deviation of the training targets
        mean = torch.mean(y, dim=0)
        std = torch.std(y, dim=0)

        # Functions for input and target scaling
        self.scale_x = lambda x: (x - bounds[0]) / (bounds[1] - bounds[0])
        self.scale_y = lambda x: (x - mean) / std

        # Functions for inverse input and target scaling
        self.unscale_x = lambda x: x * (bounds[1] - bounds[0]) + bounds[0]
        self.unscale_y = lambda x: x * std + mean

        # Functions for inverse mean and variance scaling
        # (variances are rescaled by the square of the target standard deviation)
        self.unscale_m = lambda x: x * std + mean
        self.unscale_s = lambda x: x * std**2

        # Flag that the scaler has been fitted
        self.fitted = True

        return self.scale_x(x), self.scale_y(y)
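

# Illustrative usage sketch (not part of the original module): a minimal round
# trip through DefaultScaler, relying on `to_tensor` accepting a DataFrame as it
# does above. The toy search space, training data, and posterior values below
# are made up purely for demonstration.
if __name__ == "__main__":
    import torch

    # Hypothetical two-parameter search space
    demo_searchspace = pd.DataFrame({"x1": [0.0, 1.0, 2.0], "x2": [10.0, 20.0, 30.0]})
    scaler = DefaultScaler(demo_searchspace)

    # Fit the scaler on (made-up) training data and obtain the scaled versions
    train_x = to_tensor(demo_searchspace)
    train_y = torch.tensor([[1.0], [2.0], [4.0]], dtype=train_x.dtype)
    scaled_x, scaled_y = scaler.fit_transform(train_x, train_y)

    # Scale a new candidate point into the unit cube
    candidate = torch.tensor([[0.5, 15.0]], dtype=train_x.dtype)
    print(scaler.transform(candidate))

    # Map scaled posterior means/variances back to the original target domain
    post_mean = torch.zeros(1, dtype=train_x.dtype)
    post_var = torch.ones(1, dtype=train_x.dtype)
    print(scaler.untransform(post_mean, post_var))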