Source code for gpytorchwrapper.plugins.model_h2okr_s1

from gpytorch import kernels, means, models, distributions

from gpytorchwrapper.src.kernels.linearxmatern_kernel_perminv import (
    LinearxMaternKernelPermInv,
)
from botorch.models.gpytorch import GPyTorchModel


[docs] class H2OKrS1(models.ExactGP, GPyTorchModel): def __init__(self, train_x, train_y, likelihood): super().__init__(train_x, train_y, likelihood) n_atoms = 4 idx_equiv_atoms = [[1, 2]] dims = [2, 4, 5] #l_constraint = constraints.GreaterThan(0.075) self.mean_module = means.ConstantMean() self.covar_module = kernels.ScaleKernel( LinearxMaternKernelPermInv( n_atoms=n_atoms, idx_equiv_atoms=idx_equiv_atoms, select_dims=dims, ) ) self.mean_module.raw_constant.requires_grad = False
[docs] def forward(self, x): mean_x = self.mean_module(x) covar_x = self.covar_module(x) return distributions.MultivariateNormal(mean_x, covar_x)