import math
from typing import Optional, Union
import torch
from gpytorch.constraints import Interval, Positive
from gpytorch.priors import Prior
from linear_operator.operators import MatmulLinearOperator, RootLinearOperator
from gpytorchwrapper.src.kernels.perminv_kernel import PermInvKernel
from gpytorchwrapper.src.utils.input_transformer import xyz_to_dist_torch
class LinearxMaternKernelPermInv(PermInvKernel):
has_lengthscale = True
def __init__(
self,
n_atoms: int,
idx_equiv_atoms: list[list[int]],
        select_dims: Optional[list[int]] = None,
nu: float = 2.5,
ard: bool = False,
representation: str = "invdist",
variance_prior: Optional[Prior] = None,
variance_constraint: Optional[Interval] = None,
**kwargs,
):
"""
        Initialize the LinearxMaternKernelPermInv kernel, the product of a linear kernel and a Matérn kernel.
Parameters
----------
n_atoms : int
Number of atoms in the molecule or structure.
idx_equiv_atoms : list of list of int
Groups of indices indicating equivalent atoms under permutations.
select_dims : list of int, optional
Dimensions to select from the distance representation.
nu : float, default=2.5
Smoothness parameter of the Matérn kernel. Must be one of {0.5, 1.5, 2.5}.
ard : bool, default=False
If True, use automatic relevance determination (ARD).
        representation : str, default="invdist"
            The type of representation to use for distances. Choose from:
            `invdist` for inverse distances,
            `morse` for features exp(-r_ij).
variance_prior : gpytorch.priors.Prior, optional
Prior distribution for the variance parameter.
variance_constraint : gpytorch.constraints.Interval, optional
Constraint for the variance parameter.
**kwargs
Additional keyword arguments for the base class.
Raises
------
ValueError
If `nu` is not one of {0.5, 1.5, 2.5}.
NotImplementedError
If `active_dims` is provided in `kwargs`, which is not supported.
TypeError
If `variance_prior` is not an instance of `gpytorch.priors.Prior`.
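        Examples
        --------
        A minimal usage sketch (hypothetical A-B-B triatomic in which atoms
        1 and 2 are permutationally equivalent; assumes Cartesian inputs of
        shape (n_samples, 3 * n_atoms)):
        >>> kernel = LinearxMaternKernelPermInv(
        ...     n_atoms=3, idx_equiv_atoms=[[1, 2]], nu=2.5
        ... )
        >>> x = torch.rand(10, 9)  # 10 configurations of 3 atoms
        >>> covar = kernel(x, x)   # lazy 10 x 10 covariance matrix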
"""
super().__init__(
n_atoms=n_atoms,
idx_equiv_atoms=idx_equiv_atoms,
select_dims=select_dims,
ard=ard,
**kwargs,
)
if nu not in {0.5, 1.5, 2.5}:
raise ValueError(
"Please select one of the following nu values: {0.5, 1.5, 2.5}"
)
if self.active_dims is not None:
raise NotImplementedError(
"Keyword active_dims is not supported for LinearxMaternKernelPermInv. Please use select_dims instead."
)
if variance_constraint is None:
variance_constraint = Positive()
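        # The raw (unconstrained) variance has shape
        # (*batch_shape, 1, ard_num_dims or 1) and is mapped onto the
        # positive reals by the registered constraint.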
self.register_parameter(
name="raw_variance",
parameter=torch.nn.Parameter(
torch.zeros(
*self.batch_shape,
1,
1 if self.ard_num_dims is None else self.ard_num_dims,
)
),
)
if variance_prior is not None:
if not isinstance(variance_prior, Prior):
raise TypeError(
"Expected gpytorch.priors.Prior but got "
+ type(variance_prior).__name__
)
self.register_prior(
"variance_prior",
variance_prior,
lambda m: m.variance,
lambda m, v: m._set_variance(v),
)
self.register_constraint("raw_variance", variance_constraint)
self.nu = nu
self.representation = representation
@property
def variance(self) -> torch.Tensor:
return self.raw_variance_constraint.transform(self.raw_variance)
@variance.setter
def variance(self, value: Union[float, torch.Tensor]):
self._set_variance(value)
def _set_variance(self, value: Union[float, torch.Tensor]):
if not torch.is_tensor(value):
value = torch.as_tensor(value).to(self.raw_variance)
self.initialize(
raw_variance=self.raw_variance_constraint.inverse_transform(value)
)
def matern_kernel(self, x1, x2, diag, **params):
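        """
        Evaluate the Matérn component.
        Inputs are mean-centered and divided by the (optionally ARD-expanded)
        lengthscale before the pairwise distance d is computed, following
        GPyTorch's stationary-kernel convention. The result is
        c_nu(d) * exp(-sqrt(2 * nu) * d), where c_nu(d) is 1 for nu=0.5,
        1 + sqrt(3) * d for nu=1.5, and 1 + sqrt(5) * d + 5/3 * d**2
        for nu=2.5.
        """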
mean = x1.mean(dim=-2, keepdim=True)
if self.ard:
perminv_ard_lengthscale = self.lengthscale.clone()[0][
self.ard_expansion
].unsqueeze(0)
x1_ = (x1 - mean).div(perminv_ard_lengthscale)
x2_ = (x2 - mean).div(perminv_ard_lengthscale)
else:
x1_ = (x1 - mean).div(self.lengthscale)
x2_ = (x2 - mean).div(self.lengthscale)
distance = self.covar_dist(x1_, x2_, diag=diag, **params)
exp_component = torch.exp(-math.sqrt(self.nu * 2) * distance)
if self.nu == 0.5:
constant_component = 1
elif self.nu == 1.5:
constant_component = (math.sqrt(3) * distance).add(1)
elif self.nu == 2.5:
constant_component = (
(math.sqrt(5) * distance).add(1).add(5.0 / 3.0 * distance**2)
)
else:
raise ValueError(
"Please select one of the following nu values: {0.5, 1.5, 2.5}"
)
return constant_component * exp_component
def linear_kernel(self, x1, x2, diag, last_dim_is_batch, **params):
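        """
        Evaluate the linear component x1 diag(v) x2^T, where v is the
        (optionally ARD-expanded) variance. When x1 and x2 are identical,
        a RootLinearOperator is returned so downstream operations can
        exploit the symmetric root decomposition; otherwise a lazy
        MatmulLinearOperator is built.
        """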
if self.ard:
perminv_ard_variance = self.variance.clone()[0][
self.ard_expansion
].unsqueeze(0)
x1_ = x1 * perminv_ard_variance.sqrt()
else:
x1_ = x1 * self.variance.sqrt()
if last_dim_is_batch:
x1_ = x1_.transpose(-1, -2).unsqueeze(-1)
if x1.size() == x2.size() and torch.equal(x1, x2):
# Use RootLinearOperator when x1 == x2 for efficiency when composing
# with other kernels
prod = RootLinearOperator(x1_)
else:
if self.ard:
x2_ = x2 * perminv_ard_variance.sqrt()
else:
x2_ = x2 * self.variance.sqrt()
if last_dim_is_batch:
x2_ = x2_.transpose(-1, -2).unsqueeze(-1)
prod = MatmulLinearOperator(x1_, x2_.transpose(-2, -1))
if diag:
return prod.diagonal(dim1=-1, dim2=-2)
else:
return prod
def forward(
self, x1, x2, diag=False, last_dim_is_batch: Optional[bool] = False, **params
):
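        """
        Compute the permutationally invariant kernel: convert the Cartesian
        inputs to a distance representation, then average the product
        k_linear * k_matern over all permutations of equivalent atoms
        applied to the columns of x2's representation.
        """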
k_sum = 0
num_perms = len(self.permutations)
init_perm = self.permutations[0]
x1_dist = xyz_to_dist_torch(x1, representation=self.representation)
x2_dist = (
xyz_to_dist_torch(x2, representation=self.representation)
if not torch.equal(x1, x2)
else x1_dist.clone()
)
        if self.select_dims is not None:
            # Keep the index tensor on the same device as the inputs so
            # index_select works on GPU as well as CPU.
            select_dims_tensor = torch.tensor(self.select_dims, device=x1_dist.device)
            x1_dist = torch.index_select(x1_dist, 1, select_dims_tensor)
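        # Accumulate the product kernel over every permutation of
        # equivalent atoms, applied to the columns of x2's distance
        # representation; the identity permutation leaves x2_dist unchanged.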
for perm in self.permutations:
x2_dist_perm = x2_dist.clone()
x2_dist_perm[:, init_perm] = x2_dist[:, perm]
if self.select_dims is not None:
x2_dist_perm = torch.index_select(x2_dist_perm, 1, select_dims_tensor)
k_linear = self.linear_kernel(
x1_dist, x2_dist_perm, diag, last_dim_is_batch, **params
)
k_matern = self.matern_kernel(x1_dist, x2_dist_perm, diag, **params)
k_sum += k_linear * k_matern
return 1 / num_perms * k_sum