# Source code for gpytorchwrapper.src.kernels.perminv_kernel

from typing import Optional

import torch
from gpytorch.kernels import Kernel

from gpytorchwrapper.src.utils.permutational_invariance import (
    generate_interatomic_distance_indices,
    generate_unique_distances,
    generate_ard_expansion,
    generate_dist_permutations,
)


class PermInvKernel(Kernel):
    """
    Base class for all permutationally invariant kernels.

    It sets up the permutationally invariant ARD expansion automatically
    (when ``ard=True``). It also precomputes the permutations of the
    interatomic distances and stores them in the ``permutations`` buffer.

    Parameters
    ----------
    n_atoms: int
        The total number of atoms in the system.
    idx_equiv_atoms: list[list[int]]
        List of lists representing the groups of permutationally invariant
        atoms.
    select_dims: list[int], optional
        Indices of the interatomic-distance dimensions to be selected.
        ``None`` or an empty list means that all
        ``n_atoms * (n_atoms - 1) // 2`` distances are used.
    ard: bool
        Whether to use the permutationally invariant ARD or not.
    kwargs
        Additional keyword arguments forwarded to
        ``gpytorch.kernels.Kernel.__init__``. ``ard_num_dims`` is not
        allowed; use ``ard=True`` instead.

    Attributes
    ----------
    select_dims: Tensor or None
        The selected dimension indices as a ``torch.int`` tensor, or ``None``
        when no selection was requested.
    ard_expansion: Tensor
        Only present when ``ard=True``: the output of
        ``generate_ard_expansion`` for the (selected) distance indices.
    permutations: Tensor
        The output of ``generate_dist_permutations`` for the (selected)
        distance indices.
    ard: bool
        Whether ARD is used or not.

    Raises
    ------
    NotImplementedError
        If the `ard_num_dims` keyword is used instead of `ard`.
    ValueError
        If the expected number of unique distances does not match the amount
        generated by the ARD expansion.
    """

    def __init__(
        self,
        n_atoms: int,
        idx_equiv_atoms: list[list[int]],
        select_dims: Optional[list[int]] = None,
        ard: bool = False,
        **kwargs,
    ):
        # Plain gpytorch ARD is rejected in both modes. Checking up front also
        # turns the ard=True case into the documented NotImplementedError,
        # instead of a duplicate-keyword TypeError from super().__init__.
        if kwargs.get("ard_num_dims") is not None:
            raise NotImplementedError(
                f"Regular ARD is not supported for {type(self).__name__}. "
                "Set 'ard=True' instead; the ARD expansion is generated "
                "automatically."
            )

        distance_idx = generate_interatomic_distance_indices(n_atoms)
        # An empty selection is treated exactly like "no selection".
        if select_dims:
            distance_idx = [distance_idx[i] for i in select_dims]

        if ard:
            n_dist = n_atoms * (n_atoms - 1) // 2
            ard_num_dims = len(select_dims) if select_dims else n_dist

            # permutationally unique!
            num_unique_distances = generate_unique_distances(
                n_atoms, idx_equiv_atoms
            )
            ard_expansion = generate_ard_expansion(distance_idx, idx_equiv_atoms)
            n_unique_expanded = len(set(ard_expansion))
            # NOTE(review): with select_dims this compares against the unique
            # count of the *full* system; a selection that drops a whole
            # equivalence class will fail here — confirm this is intended.
            if num_unique_distances != n_unique_expanded:
                raise ValueError(
                    "The permutationally invariant ARD expansion failed. "
                    f"Expected number of unique distances {num_unique_distances} "
                    f"!= {n_unique_expanded}. "
                    f"ARD expansion: {ard_expansion}"
                )

            super().__init__(ard_num_dims=ard_num_dims, **kwargs)
            self.register_buffer("ard_expansion", torch.tensor(ard_expansion))
        else:
            super().__init__(**kwargs)

        # Buffers can only be registered after Module.__init__ has run;
        # register_buffer raises AttributeError otherwise.
        if select_dims:
            self.register_buffer(
                "select_dims", torch.tensor(select_dims, dtype=torch.int)
            )
        else:
            self.select_dims = None

        # Permutations are built from the (possibly restricted) distance
        # indices in both modes, so they match the selected dimensions.
        permutations = generate_dist_permutations(distance_idx, idx_equiv_atoms)
        self.register_buffer("permutations", permutations)
        self.ard = ard