import math
from typing import Optional
import torch
from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior
from gpytorchwrapper.src.kernels.perminv_kernel import PermInvKernel
from gpytorchwrapper.src.utils.input_transformer import xyz_to_dist_torch
class PolyxMaternKernelPermInv(PermInvKernel):
has_lengthscale = True
def __init__(
self,
n_atoms: int,
idx_equiv_atoms: list[list[int]],
power: int,
select_dims: Optional[list[int]] = None,
nu: float = 2.5,
ard: bool = False,
representation: str = "invdist",
offset_prior: Optional[Prior] = None,
offset_constraint: Optional[Interval] = None,
variance_prior: Optional[Prior] = None,
variance_constraint: Optional[Interval] = None,
**kwargs,
):
"""
Initialize the PolyxMaternKernelPermInv kernel, the product of a polynomial kernel and a Matérn kernel.
The polynomial kernel reduces to a linear kernel with an offset if `power = 1`.
Parameters
----------
n_atoms : int
Number of atoms in the molecule or structure.
idx_equiv_atoms : list of list of int
Groups of indices indicating equivalent atoms under permutations.
power : int
The exponent used in the polynomial kernel component.
select_dims : list of int, optional
Dimensions to select from the distance representation.
nu : float, default=2.5
Smoothness parameter of the Matérn kernel. Must be one of {0.5, 1.5, 2.5}.
ard : bool, default=False
If True, use automatic relevance determination (ARD).
representation : str, default="invdist"
The type of representation to use for distances. Choose from:
`invdist` for inverse distances
`morse` for features exp(-r_ij)
offset_prior : gpytorch.priors.Prior, optional
Prior distribution for the offset parameter.
offset_constraint : gpytorch.constraints.Interval, optional
Constraint for the offset parameter.
variance_prior : gpytorch.priors.Prior, optional
Prior distribution for the variance parameter.
variance_constraint : gpytorch.constraints.Interval, optional
Constraint for the variance parameter.
**kwargs
Additional keyword arguments for the base class.
Raises
------
NotImplementedError
If `nu` is not one of {0.5, 1.5, 2.5}.
NotImplementedError
If `active_dims` is provided in `kwargs`, which is not supported.
RuntimeError
If `power` is a tensor with more than one element.
TypeError
If `offset_prior` or `variance_prior` is not an instance of `gpytorch.priors.Prior`.
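Examples
--------
A minimal, illustrative construction (the atom indices and hyperparameter
values below are placeholders, not recommendations)::

    kernel = PolyxMaternKernelPermInv(
        n_atoms=3,
        idx_equiv_atoms=[[1, 2]],
        power=1,
        nu=2.5,
        ard=False,
        representation="invdist",
    )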
"""
super().__init__(
n_atoms=n_atoms,
idx_equiv_atoms=idx_equiv_atoms,
select_dims=select_dims,
ard=ard,
**kwargs,
)
if nu not in {0.5, 1.5, 2.5}:
raise NotImplementedError(
"Please select one of the following nu values: {0.5, 1.5, 2.5}"
)
if self.active_dims is not None:
raise NotImplementedError(
"Keyword active_dims is not supported for LinearxMaternKernelPermInv. Please use select_dims instead."
)
if offset_constraint is None:
offset_constraint = Positive()
self.register_parameter(
name="raw_offset",
parameter=torch.nn.Parameter(torch.zeros(*self.batch_shape, 1)),
)
# We want the power to be a float so we don't have to worry about its device / dtype.
if torch.is_tensor(power):
if power.numel() > 1:
raise RuntimeError(
"Cant create a Polynomial kernel with more than one power"
)
else:
power = power.item()
self.power = power
if offset_prior is not None:
if not isinstance(offset_prior, Prior):
raise TypeError(
"Expected gpytorch.priors.Prior but got "
+ type(offset_prior).__name__
)
self.register_prior(
"offset_prior",
offset_prior,
lambda m: m.offset,
lambda m, v: m._set_offset(v),
)
self.register_constraint("raw_offset", offset_constraint)
if variance_constraint is None:
variance_constraint = Positive()
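# The raw variance holds one entry per input dimension when ARD is enabled,
# otherwise a single shared entry.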
self.register_parameter(
name="raw_variance",
parameter=torch.nn.Parameter(
torch.zeros(
*self.batch_shape,
1,
1 if self.ard_num_dims is None else self.ard_num_dims,
)
),
)
if variance_prior is not None:
if not isinstance(variance_prior, Prior):
raise TypeError(
"Expected gpytorch.priors.Prior but got "
+ type(variance_prior).__name__
)
self.register_prior(
"variance_prior",
variance_prior,
lambda m: m.variance,
lambda m, v: m._set_variance(v),
)
self.register_constraint("raw_variance", variance_constraint)
self.nu = nu
self.representation = representation
@property
def offset(self) -> torch.Tensor:
return self.raw_offset_constraint.transform(self.raw_offset)
@offset.setter
def offset(self, value: torch.Tensor) -> None:
self._set_offset(value)
def _set_offset(self, value: torch.Tensor) -> None:
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_offset)
self.initialize(raw_offset=self.raw_offset_constraint.inverse_transform(value))
@property
def variance(self) -> torch.Tensor:
return self.raw_variance_constraint.transform(self.raw_variance)
@variance.setter
def variance(self, value: float | torch.Tensor):
self._set_variance(value)
def _set_variance(self, value: float | torch.Tensor):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_variance)
self.initialize(
raw_variance=self.raw_variance_constraint.inverse_transform(value)
)
def matern_kernel(self, x1, x2, diag, **params):
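"""
Evaluate the Matérn component on the distance-feature inputs.

Both inputs are centred on the mean of `x1` and divided by the lengthscale
(expanded per input dimension via `self.ard_expansion` when ARD is enabled)
before the pairwise distance `d` is computed. For the supported smoothness
values the kernel is `c(d) * exp(-sqrt(2 * nu) * d)` with `c(d) = 1`
(nu=0.5), `1 + sqrt(3) * d` (nu=1.5), or `1 + sqrt(5) * d + 5/3 * d**2`
(nu=2.5).
"""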
mean = x1.mean(dim=-2, keepdim=True)
if self.ard:
perminv_ard_lengthscale = self.lengthscale.clone()[0][
self.ard_expansion
].unsqueeze(0)
x1_ = (x1 - mean).div(perminv_ard_lengthscale)
x2_ = (x2 - mean).div(perminv_ard_lengthscale)
else:
x1_ = (x1 - mean).div(self.lengthscale)
x2_ = (x2 - mean).div(self.lengthscale)
distance = self.covar_dist(x1_, x2_, diag=diag, **params)
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)
if self.nu == 0.5:
constant_component = 1
elif self.nu == 1.5:
constant_component = (math.sqrt(3) * distance).add(1)
elif self.nu == 2.5:
constant_component = (
(math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2)
)
else:
raise NotImplementedError(
"Please select one of the following nu values: {0.5, 1.5, 2.5}"
)
return constant_component * exp_component
def polynomial_kernel(self, x1, x2, diag, last_dim_is_batch, **params):
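"""
Evaluate the polynomial component on the distance-feature inputs.

The inputs are weighted by the square root of the variance (expanded per
input dimension via `self.ard_expansion` when ARD is enabled), so the
result is `(x1 @ diag(variance) @ x2.T + offset) ** power`.
"""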
offset = self.offset.view(*self.batch_shape, 1, 1)
if self.ard:
perminv_ard_variance = self.variance.clone()[0][
self.ard_expansion
].unsqueeze(0)
x1_ = x1 * perminv_ard_variance.sqrt()
x2_ = x2 * perminv_ard_variance.sqrt()
else:
x1_ = x1 * self.variance.sqrt()
x2_ = x2 * self.variance.sqrt()
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
if diag:
return ((x1_ * x2_).sum(dim=-1) + self.offset).pow(self.power)
if (x1_.dim() == 2 and x2_.dim() == 2) and offset.dim() == 2:
return torch.addmm(offset, x1_, x2_.transpose(-2, -1)).pow(self.power)
else:
return (torch.matmul(x1_, x2_.transpose(-2, -1)) + offset).pow(self.power)
def forward(
self, x1, x2, diag=False, last_dim_is_batch: Optional[bool] = False, **params
):
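"""
Compute the permutationally invariant product kernel.

The Cartesian inputs are first mapped to the chosen distance
representation (and restricted to `select_dims` if given). The polynomial
and Matérn components are then evaluated with the second input's columns
permuted by every allowed relabelling of equivalent atoms, and the
products are averaged over all permutations.
"""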
k_sum = 0
num_perms = len(self.permutations)
init_perm = self.permutations[0]
x1_dist = xyz_to_dist_torch(x1, representation=self.representation)
x2_dist = (
xyz_to_dist_torch(x2, representation=self.representation)
if not torch.equal(x1, x2)
else x1_dist.clone()
)
if self.select_dims is not None:
select_dims_tensor = torch.tensor(self.select_dims, device=x1_dist.device)
x1_dist = torch.index_select(x1_dist, 1, select_dims_tensor)
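# Average the product kernel over all permutations of equivalent atoms:
# k(x1, x2) = 1/|P| * sum_p k_poly(d1, d2[p]) * k_matern(d1, d2[p])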
for perm in self.permutations:
x2_dist_perm = x2_dist.clone()
x2_dist_perm[:, init_perm] = x2_dist[:, perm]
if self.select_dims is not None:
x2_dist_perm = torch.index_select(x2_dist_perm, 1, select_dims_tensor)
k_poly = self.polynomial_kernel(
x1_dist, x2_dist_perm, diag, last_dim_is_batch, **params
)
k_matern = self.matern_kernel(x1_dist, x2_dist_perm, diag, **params)
k_sum += k_poly * k_matern
return 1 / num_perms * k_sum