from gpytorch import kernels, means, models, distributions, constraints, priors
from gpytorchwrapper.src.kernels.polyxmatern_kernel_perminv import (
PolyxMaternKernelPermInv,
)
from botorch.models.gpytorch import GPyTorchModel
[docs]
class ArH2pS0(models.ExactGP, GPyTorchModel):
def __init__(self, train_x, train_y, likelihood):
super().__init__(train_x, train_y, likelihood)
outputscale_prior = priors.NormalPrior(5.0, 2.0)
lengthscale_prior = priors.NormalPrior(0.5, 0.4)
variance_prior = priors.NormalPrior(0.5, 0.4)
n_atoms = 3
idx_equiv_atoms = [[0, 1]]
self.mean_module = means.ConstantMean()
self.covar_module = kernels.ScaleKernel(
PolyxMaternKernelPermInv(
n_atoms=n_atoms,
idx_equiv_atoms=idx_equiv_atoms,
ard=True,
nu=2.5,
lengthscale_prior=lengthscale_prior,
power=1,
representation="morse",
variance_constraint=constraints.Positive(),
)
)
self.covar_module.base_kernel.lengthscale = [lengthscale_prior.mean] * 3
self.covar_module.base_kernel.variance = [variance_prior.mean] * 3
self.covar_module.outputscale = outputscale_prior.mean
self.mean_module.constant = 4.0
self.mean_module.raw_constant.requires_grad = False
[docs]
def forward(self, x):
mean_x = self.mean_module(x)
covar_x = self.covar_module(x)
return distributions.MultivariateNormal(mean_x, covar_x)