Source code for baybe.objectives.botorch
"""BoTorch objectives."""
from botorch.acquisition.objective import MCAcquisitionObjective
from torch import Tensor
from baybe.utils.basic import compose
[docs]
class ChainedMCObjective(MCAcquisitionObjective):
"""A chained Monte Carlo objective."""
[docs]
def __init__(self, *objectives: MCAcquisitionObjective) -> None:
super().__init__()
self.objectives = objectives
[docs]
def forward(self, samples: Tensor, X: Tensor | None = None) -> Tensor: # noqa: D102
return compose(*(o.forward for o in self.objectives))(samples, X)