Source code for gpytorchwrapper.src.kernels.rbf_kernel_perminv

from typing import Optional

import torch
from gpytorch.kernels.rbf_kernel import postprocess_rbf

from gpytorchwrapper.src.kernels.perminv_kernel import PermInvKernel
from gpytorchwrapper.src.utils.input_transformer import xyz_to_dist_torch


class RBFKernelPermInv(PermInvKernel):
    """
    RBF kernel averaged over a set of atom permutations.

    The Cartesian inputs are converted to a distance-based representation
    with ``xyz_to_dist_torch``; the kernel value is the mean of RBF kernels
    evaluated between ``x1`` and every permuted copy of ``x2``.

    NOTE(review): ``self.permutations``, ``self.select_dims``, ``self.ard``
    and ``self.ard_expansion`` are not assigned in this class; they are
    presumably set up by ``PermInvKernel.__init__`` -- confirm there.
    """

    has_lengthscale = True

    def __init__(
        self,
        n_atoms: int,
        idx_equiv_atoms: list[list[int]],
        select_dims: Optional[list[int]] = None,
        ard: bool = False,
        representation: str = "invdist",
        **kwargs,
    ):
        """
        Initialize the RBFKernelPermInv kernel.

        Parameters
        ----------
        n_atoms : int
            Number of atoms in the molecule or structure.
        idx_equiv_atoms : list of list of int
            Groups of indices indicating equivalent atoms under permutations.
        select_dims : list of int, optional
            Dimensions to select from the distance representation.
        ard : bool, default=False
            If True, use automatic relevance determination (ARD).
        representation : str, default="invdist"
            The type of representation to use for distances, choose from:
                `invdist` for inverse distances
                `morse` for features exp(-r_ij)
        **kwargs
            Additional keyword arguments for the base class.

        Raises
        ------
        NotImplementedError
            If `active_dims` is provided in `kwargs`, which is not supported.
        """
        super().__init__(
            n_atoms=n_atoms,
            idx_equiv_atoms=idx_equiv_atoms,
            select_dims=select_dims,
            ard=ard,
            **kwargs,
        )

        # The check runs after the base constructor because that is where
        # ``active_dims`` (if passed through **kwargs) ends up on the instance.
        if self.active_dims is not None:
            raise NotImplementedError(
                "Keyword active_dims is not supported for this kernel. Please use select_dims instead."
            )

        self.representation = representation

    def rbf_kernel(self, x1, x2, diag, **params):
        """
        Evaluate the RBF kernel between two sets of feature vectors.

        Both inputs are shifted by the mean of ``x1`` (taken over dim -2) and
        divided by the lengthscale before the squared distance is computed.
        Shifting both inputs by the same vector leaves their pairwise
        differences unchanged (presumably done for numerical stability).

        Parameters
        ----------
        x1, x2 : torch.Tensor
            Feature tensors; the second-to-last dimension is the one the
            mean is taken over.
        diag : bool
            Forwarded to ``covar_dist``.
        **params
            Extra keyword arguments forwarded to ``covar_dist``.

        Returns
        -------
        torch.Tensor
            ``postprocess_rbf`` applied to the squared distances between the
            scaled inputs.
        """
        mean = x1.mean(dim=-2, keepdim=True)

        if self.ard:
            # Expand the first row of the lengthscale to one entry per
            # feature via ``ard_expansion`` and restore a leading dimension
            # so it broadcasts against the inputs.
            perminv_ard_lengthscale = self.lengthscale.clone()[0][
                self.ard_expansion
            ].unsqueeze(0)
            x1_ = (x1 - mean).div(perminv_ard_lengthscale)
            x2_ = (x2 - mean).div(perminv_ard_lengthscale)
        else:
            x1_ = (x1 - mean).div(self.lengthscale)
            x2_ = (x2 - mean).div(self.lengthscale)

        return postprocess_rbf(
            self.covar_dist(x1_, x2_, square_dist=True, diag=diag, **params)
        )

    def forward(
        self, x1, x2, diag=False, last_dim_is_batch: Optional[bool] = False, **params
    ):
        """
        Compute the permutation-averaged RBF kernel between ``x1`` and ``x2``.

        Parameters
        ----------
        x1, x2 : torch.Tensor
            Inputs passed to ``xyz_to_dist_torch``. The result is indexed
            along dimension 1 (assumes 2-D ``(n_points, n_features)``
            output -- TODO confirm against ``xyz_to_dist_torch``).
        diag : bool, default=False
            Forwarded to ``rbf_kernel``.
        last_dim_is_batch : bool, optional
            Accepted for interface compatibility; not used by this method.
        **params
            Extra keyword arguments forwarded to ``rbf_kernel``.

        Returns
        -------
        torch.Tensor
            Sum of the kernels over all permutations, scaled by
            ``1 / len(self.permutations)``.
        """
        k_sum = 0
        num_perms = len(self.permutations)
        # Column positions that receive the permuted columns of x2.
        init_perm = self.permutations[0]

        x1_dist = xyz_to_dist_torch(x1, representation=self.representation)
        # Skip the second conversion when both inputs hold the same values.
        x2_dist = (
            xyz_to_dist_torch(x2, representation=self.representation)
            if not torch.equal(x1, x2)
            else x1_dist.clone()
        )

        use_select_dims = self.select_dims is not None
        if use_select_dims:
            # Build the index on the same device as the data: index_select
            # raises on a CPU index with a GPU input.
            select_dims_tensor = torch.tensor(
                self.select_dims, device=x1_dist.device
            )
            x1_dist = torch.index_select(x1_dist, 1, select_dims_tensor)

        for perm in self.permutations:
            # Write the columns listed in ``perm`` into the ``init_perm``
            # positions of a fresh copy, leaving ``x2_dist`` untouched for
            # the next permutation.
            x2_dist_perm = x2_dist.clone()
            x2_dist_perm[:, init_perm] = x2_dist[:, perm]
            if use_select_dims:
                x2_dist_perm = torch.index_select(
                    x2_dist_perm, 1, select_dims_tensor
                )
            k_sum += self.rbf_kernel(x1_dist, x2_dist_perm, diag=diag, **params)

        return 1 / num_perms * k_sum