Source code for gpytorchwrapper.plugins.model_h2okr_s0

from gpytorch import kernels, means, models, distributions

from gpytorchwrapper.src.kernels import (
    PolyxMaternKernelPermInv,
)
from botorch.models.gpytorch import GPyTorchModel


[docs] class H2OKrS0(models.ExactGP, GPyTorchModel): def __init__(self, train_x, train_y, likelihood): super().__init__(train_x, train_y, likelihood) n_atoms = 4 idx_equiv_atoms = [[1, 2]] dims = [2, 4, 5] # l_constraint = constraints.GreaterThan(0.075) self.mean_module = means.ConstantMean() self.covar_module = kernels.ScaleKernel( PolyxMaternKernelPermInv( n_atoms=n_atoms, idx_equiv_atoms=idx_equiv_atoms, select_dims=dims, ard=True, nu=2.5, power=1, representation="morse", ) ) self.mean_module.constant = 0.0 self.mean_module.raw_constant.requires_grad = False
[docs] def forward(self, x): mean_x = self.mean_module(x) covar_x = self.covar_module(x) return distributions.MultivariateNormal(mean_x, covar_x)