import itertools
import torch
[docs]
def generate_permutations(idx_equiv_atoms: list[list[int]]) -> torch.Tensor:
    """
    Enumerate every simultaneous reordering of groups of equivalent atoms.

    Each group is permuted independently, and one ordering of every group is
    concatenated into a single row. The result therefore has
    ``prod(len(g)! for g in idx_equiv_atoms)`` rows and
    ``sum(len(g) for g in idx_equiv_atoms)`` columns. Rows follow
    ``itertools.product`` order, so the first row lists the groups exactly as
    given.

    Parameters
    ----------
    idx_equiv_atoms : list[list[int]]
        List containing lists of indices for equivalent atoms.

    Returns
    -------
    torch.Tensor
        Integer (``torch.long``) tensor of all possible permutations. Only the
        indices listed in ``idx_equiv_atoms`` appear in each row; atoms that
        belong to no group are not included. An empty ``idx_equiv_atoms``
        yields a tensor of shape ``(1, 0)``.

    Example
    -------
    For the reaction between N2 and H3+, the nitrogen atoms have indices 0 and 1,
    while the hydrogen atoms have indices 2, 3, and 4.
    The `idx_equiv_atoms` list should look like [[0, 1], [2, 3, 4]], or a subset
    such as [[2, 3, 4]] when only the hydrogen atoms are to be permuted.
    """
    # Every ordering of each group, computed independently.
    per_group = [list(itertools.permutations(group)) for group in idx_equiv_atoms]
    # Choose one ordering per group and join them into one flat row; chain avoids
    # the quadratic list concatenation that sum(..., []) performs.
    rows = [
        list(itertools.chain.from_iterable(choice))
        for choice in itertools.product(*per_group)
    ]
    # Explicit dtype: otherwise the empty case ([[]]) would become float32,
    # which cannot be used as an index tensor.
    return torch.tensor(rows, dtype=torch.long)
[docs]
def generate_unique_distances(num_atoms: int, idx_equiv_atoms: list[list[int]]) -> int:
    """
    Count the interatomic distances that stay distinct under permutational invariance.

    Parameters
    ----------
    num_atoms : int
        The total number of atoms in the system.
    idx_equiv_atoms : list[list[int]]
        List of lists representing the groups of permutationally invariant atoms.
        Groups are assumed to be disjoint. Single-atom and empty groups are
        accepted; they contribute no intra-group distance.

    Returns
    -------
    num_unique_dist : int
        The number of unique distances in the system taking into account
        permutational invariance.

    Examples
    --------
    The H2O system contains two permutationally invariant hydrogen atoms H1 and H2.
    The energy is invariant to the permutation of the distances O-H1 and O-H2.
    Therefore there are 2 unique distances in the system: O-H and H-H
    (``generate_unique_distances(3, [[1, 2]]) == 2``).
    The general formula is,
    unique distances = n(n-1)/2 + k,
    where n is the number of atom groups (every atom not listed in
    ``idx_equiv_atoms`` counts as a group of its own) and k is the number of
    groups containing more than a single atom.
    """
    # Empty groups hold no atoms and therefore form no group at all.
    groups = [group for group in idx_equiv_atoms if group]
    num_grouped_atoms = sum(len(group) for group in groups)
    # Each atom outside every group is its own group of one.
    num_groups = (num_atoms - num_grouped_atoms) + len(groups)
    # Only groups with at least two atoms have an intra-group distance (e.g. H-H);
    # counting single-atom groups here would over-count.
    num_multi_atom_groups = sum(1 for group in groups if len(group) > 1)
    return num_groups * (num_groups - 1) // 2 + num_multi_atom_groups
[docs]
def generate_interatomic_distance_indices(num_atoms: int) -> list[list[int]]:
    """
    List every unordered pair of atom indices.

    Parameters
    ----------
    num_atoms : int
        Number of atoms; indices run from 0 to ``num_atoms - 1``.

    Returns
    -------
    list[list[int]]
        ``num_atoms * (num_atoms - 1) / 2`` pairs ``[i, j]`` with ``i < j``, in
        lexicographic order: ``[0, 1], [0, 2], ..., [1, 2], ...``.
        Empty when ``num_atoms < 2``.
    """
    # combinations yields each unordered pair exactly once, already sorted and in
    # lexicographic order, so no de-duplication pass (and no O(n) list-membership
    # test per pair) is needed.
    return [list(pair) for pair in itertools.combinations(range(num_atoms), 2)]
[docs]
def generate_ard_expansion(
    distance_idx: list[list[int]], idx_inv_atoms: list[list[int]]
) -> list[int]:
    """
    Label each interatomic distance by its equivalence class under atom permutations.

    Every atom that belongs to a group in ``idx_inv_atoms`` is replaced by the
    smallest index of that group. Two distances receive the same label when
    their mapped, sorted atom pairs coincide.

    Parameters
    ----------
    distance_idx : list[list[int]]
        Atom-index pairs, one per distance (such as the output of
        ``generate_interatomic_distance_indices``).
    idx_inv_atoms : list[list[int]]
        Groups of permutationally invariant atoms. If an atom appears in
        several groups, the last group containing it determines its
        representative.

    Returns
    -------
    list[int]
        One label per entry of ``distance_idx``. Labels are consecutive
        integers starting at 0, assigned in order of first appearance.

    Examples
    --------
    For H2O with O = 0 and H = 1, 2, ``distance_idx = [[0, 1], [0, 2], [1, 2]]``
    and ``idx_inv_atoms = [[1, 2]]`` give ``[0, 0, 1]`` (O-H, O-H, H-H).
    """
    # Representative (smallest index) for every grouped atom; later groups
    # overwrite earlier ones. Ungrouped atoms represent themselves.
    representative: dict[int, int] = {}
    for group in idx_inv_atoms:
        if not group:
            continue
        rep = min(group)
        for atom in group:
            representative[atom] = rep

    labels: dict[tuple[int, ...], int] = {}
    expansion: list[int] = []
    for pair in distance_idx:
        # Sorting makes (i, j) and (j, i) the same key.
        key = tuple(sorted(representative.get(atom, atom) for atom in pair))
        # Labels are 0..k-1, so the next new label is simply len(labels);
        # no need to rescan the existing values with max().
        expansion.append(labels.setdefault(key, len(labels)))
    return expansion
[docs]
def generate_dist_permutations(
    distance_idx: list[list[int]], idx_inv_atoms: list[list[int]]
) -> torch.Tensor:
    """
    Enumerate reorderings of distance positions induced by permuting equivalent atoms.

    Every atom that belongs to a group in ``idx_inv_atoms`` is replaced by the
    smallest index of that group. Distances whose mapped, sorted atom pairs
    coincide form one class. All classes with more than one member are passed
    to ``generate_permutations``.

    Parameters
    ----------
    distance_idx : list[list[int]]
        Atom-index pairs, one per distance.
    idx_inv_atoms : list[list[int]]
        Groups of permutationally invariant atoms. If an atom appears in
        several groups, the last group containing it determines its
        representative.

    Returns
    -------
    torch.Tensor
        Integer (``torch.long``) tensor whose entries are positions in
        ``distance_idx`` (not atom indices). Only positions belonging to
        classes of two or more distances appear. Each row concatenates one
        ordering per class, with classes in order of first appearance.
    """
    # Representative (smallest index) for every grouped atom; later groups
    # overwrite earlier ones. Ungrouped atoms represent themselves.
    representative: dict[int, int] = {}
    for group in idx_inv_atoms:
        if not group:
            continue
        rep = min(group)
        for atom in group:
            representative[atom] = rep

    # Positions in distance_idx, grouped by equivalence class (insertion order).
    classes: dict[tuple[int, ...], list[int]] = {}
    for position, pair in enumerate(distance_idx):
        key = tuple(sorted(representative.get(atom, atom) for atom in pair))
        classes.setdefault(key, []).append(position)

    # Singleton classes have nothing to permute.
    permutable = [positions for positions in classes.values() if len(positions) > 1]
    # .long() keeps the result an integer index tensor even when `permutable`
    # is empty (torch.tensor([[]]) would otherwise be float32); it is a no-op
    # when the tensor is already int64.
    return generate_permutations(permutable).long()