Source code for gpytorchwrapper.src.utils.input_transformer

import torch


def invdist(x: torch.Tensor):
    """
    Raise every element of ``x`` to the power -1.

    Parameters
    ----------
    x : torch.Tensor
        Values to invert element-wise.

    Returns
    -------
    torch.Tensor
        Tensor holding ``x ** -1`` element-wise.
    """
    exponent = -1
    return torch.pow(x, exponent)
def morse(x: torch.Tensor):
    """
    Apply the exponential decay ``exp(-x)`` to every element of ``x``.

    Parameters
    ----------
    x : torch.Tensor
        Values to transform element-wise.

    Returns
    -------
    torch.Tensor
        Tensor holding ``exp(-x)`` element-wise.
    """
    negated = torch.neg(x)
    return torch.exp(negated)
[docs] def xyz_to_dist_torch( x: torch.Tensor, index: bool = False, representation: str = "invdist" ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: """ x is a tensor of shape (n, m) where m is the number of individual x, y, z coordinates and n is the number of data points. The x, y, z coordinates must be ordered as x1, y1, z1, x2, y2, z2, ... , xn, yn, zn The final tensor containing the interatomic distances will have the shape (n, m/3) where m/3 is the number of atoms. The order of the distances is d01, d02, ..., d12, d13, ..., d(m/3-2)(m/3-1) Parameters ---------- representation x : torch.Tensor The input tensor of shape (n, m) where m is the number of individual x, y, z coordinates index : bool returns unique atom indices per distance Returns ------- torch.Tensor | tuple[torch.Tensor, torch.Tensor] The inverse interatomic distances tensor or a tuple containing the inverse interatomic distances tensor and the unique atom indices per distance """ if len(x.shape) == 1: x = x.reshape(1, -1) elif len(x.shape) > 2: x = x.reshape(x.shape[0], -1) n, m = x.shape num_atoms = m // 3 coords = x.reshape(n, num_atoms, 3) # Calculate pairwise distances diff = coords[:, :, None, :] - coords[:, None, :, :] dist = torch.sqrt(torch.sum(diff**2, dim=-1) + 1e-8) # Create a mask to zero out the diagonal (self-distances) mask = torch.eye(num_atoms, dtype=torch.bool) dist = dist.masked_fill(mask, 0) # Upper triangular indices triu_indices = torch.triu_indices(num_atoms, num_atoms, offset=1) # Get the upper triangular part of the distance matrix interdist = dist[:, triu_indices[0], triu_indices[1]] if representation == "invdist": interdist = invdist(interdist) elif representation == "morse": interdist = morse(interdist) if index: return interdist, torch.transpose(triu_indices, -1, -2) return interdist