Coverage for hiphive/core/utilities.py: 99%
109 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
1"""
2The ``utilities`` module contains various support functions and classes.
3"""
4import sympy
5import numpy as np
6from scipy.sparse import coo_matrix
7from collections import defaultdict
8from ase import Atoms
def ase_atoms_to_spglib_tuple(structure: Atoms) -> tuple:
    """
    Collects the three quantities describing a structure from an ASE Atoms
    object and returns them as a tuple in the order: cell metric, scaled
    atomic positions, and atomic numbers.
    """
    cell = structure.cell
    scaled_positions = structure.get_scaled_positions()
    atomic_numbers = structure.get_atomic_numbers()
    return (cell, scaled_positions, atomic_numbers)
class SparseMatrix(sympy.SparseMatrix):
    """ Sparse sympy matrix extended with a sparse row reduction, a null
    space routine built on top of it, and a conversion to a numpy array.
    """

    def rref_sparse(self, simplify=False):
        """ A sparse row reduce algorithm mimicking the dense version.

        Parameters
        ----------
        simplify
            accepted so that the signature matches the dense ``rref``;
            it is not used by this implementation

        Returns
        -------
        tuple
            the row reduced matrix, which has the same shape as this
            matrix, and a tuple with the column index of each pivot
        """

        def multiply_and_add_row(row1, factor, row2):
            """ Does row1 += factor * row2 when rows are represented by dicts.
            """
            # Entries that cancel are removed afterwards so that a key being
            # present in a row always means a non-zero element.
            keys_to_delete = []
            for k, v in row2.items():
                row1[k] += factor * v
                if row1[k] == 0:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del row1[k]

        # The matrix is represented as a list of dicts where each dict is a row
        M = defaultdict(lambda: defaultdict(lambda: 0))

        # Init our special representation. This is possible due to defaultdict.
        # Only rows for which row_list() reports an entry end up in M.
        for r, c, v in self.row_list():
            M[r][c] = v
        M = list(M.values())

        # The pivot elements are stored as tuples as (row, col)
        pivots = []
        r, c = 0, 0  # current row and col of possible pivot
        nRows, nCols = len(M), self.shape[1]
        while r < nRows and c < nCols:
            row = M[r]
            # check if proposed pivot is zero. if so swap this row with a row
            # below which has non zero element at that col
            if c not in row:
                for r2 in range(r + 1, nRows):
                    row2 = M[r2]
                    if c in row2:  # This row has element in the current col
                        M[r], M[r2] = row2, row  # swap the rows
                        row = row2
                        break
                else:  # The pivot and all elements below are zero.
                    c += 1  # goto next column but stay on row
                    continue
            pivots.append((r, c))
            # Normalize row
            row_c = row[c]
            for k in row:
                row[k] /= row_c
            # Start elimination
            r2 = r + 1
            while r2 < nRows:
                row2 = M[r2]
                if c not in row2:
                    r2 += 1
                    continue
                multiply_and_add_row(row2, -row2[c], row)
                if len(row2) == 0:
                    # The row was eliminated completely; drop it and stay on
                    # the same index, which now holds the following row.
                    nRows -= 1
                    del M[r2]
                    continue
                r2 += 1
            r += 1
            c += 1

        # Eliminate elements above pivots
        for (r, p) in pivots:
            row_p = M[r]
            for i in range(r):
                row = M[i]
                if p in row:
                    multiply_and_add_row(row, -row[p], row_p)

        # Create the new rrefd matrix
        M2 = SparseMatrix(*self.shape, 0)
        for i, d in enumerate(M):
            for j, v in d.items():
                M2[i, j] = v

        # Only the column indices of the pivots are returned
        pivots = tuple(p[1] for p in pivots)

        return M2, pivots

    def nullspace(self, simplify=False):
        """ This is a slightly patched version which also uses the sparse rref
        and is faster due to up-front creation of empty SparseMatrix
        vectors instead of conversion of the finished vectors.

        Parameters
        ----------
        simplify
            forwarded to ``rref`` (small matrices) or ``rref_sparse``

        Returns
        -------
        list
            one column vector with ``self.cols`` rows per free variable
        """
        if max(self.shape) < 10:  # If matrix small use the dense version
            reduced, pivots = self.rref(simplify=simplify)
        else:
            reduced, pivots = self.rref_sparse(simplify=simplify)

        # A set keeps the membership test constant time per column
        pivot_cols = set(pivots)
        free_vars = [i for i in range(self.cols) if i not in pivot_cols]

        basis = []
        for free_var in free_vars:
            # for each free variable, we will set it to 1 and all others
            # to 0.  Then, we will use back substitution to solve the system

            # initialize each vector as an empty SparseMatrix
            vec = self._new(self.cols, 1, 0)
            vec[free_var] = 1
            for piv_row, piv_col in enumerate(pivots):
                vec[piv_col] -= reduced[piv_row, free_var]
            basis.append(vec)

        return basis

    def to_array(self):
        """ Cast SparseMatrix instance to numpy array.

        Every entry reported by ``row_list()`` is converted with
        ``np.float64``; the result has the same shape as this matrix.
        """
        row, col, data = [], [], []
        for r, c, v in self.row_list():
            row.append(r)
            col.append(c)
            data.append(np.float64(v))
        M = coo_matrix((data, (row, col)), shape=self.shape)
        return M.toarray()
class BiMap:
    """Simple list like structure with fast dict-lookup.

    The structure can append objects and supports some simple list interfaces.
    The lookup is fast since an internal dict stores the indices.

    Values must be hashable since they are used as dict keys. If the same
    value is appended more than once every occurrence is kept in the list,
    but ``index`` returns the position of the most recent one.
    """

    def __init__(self):
        self._list = list()  # values in insertion order
        self._dict = dict()  # value -> position in self._list

    def __contains__(self, value) -> bool:
        return value in self._dict

    def __getitem__(self, index):
        return self._list[index]

    def append(self, value) -> None:
        """bm.append(hashable) -> None -- append hashable object to end"""
        self._dict[value] = len(self._list)
        self._list.append(value)

    def __len__(self) -> int:
        return len(self._list)

    def index(self, value) -> int:
        """Returns index of value.
        Raises ValueError if the value is not present.
        """
        try:
            return self._dict[value]
        except KeyError:
            # The KeyError is an implementation detail of the dict lookup;
            # suppress it so that callers only see the ValueError.
            raise ValueError(f'{value} is not in list') from None

    def copy(self) -> 'BiMap':
        """ A shallow copy of the BiMap. """
        tbm = BiMap()
        tbm._list = self._list.copy()
        tbm._dict = self._dict.copy()
        return tbm