Coverage for hiphive/core/utilities.py: 99%

109 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-01 17:04 +0000

1""" 

2The ``utilities`` module contains various support functions and classes. 

3""" 

4import sympy 

5import numpy as np 

6from scipy.sparse import coo_matrix 

7from collections import defaultdict 

8from ase import Atoms 

9 

10 

def ase_atoms_to_spglib_tuple(structure: Atoms) -> tuple:
    """
    Collects the three pieces of an ASE Atoms object that make up the
    returned tuple, in this order: the cell metric, the scaled atomic
    positions, and the atomic numbers of the species.
    """
    cell = structure.cell
    scaled_positions = structure.get_scaled_positions()
    atomic_numbers = structure.get_atomic_numbers()
    return cell, scaled_positions, atomic_numbers

17 

18 

class SparseMatrix(sympy.SparseMatrix):
    """ Subclass of ``sympy.SparseMatrix`` that adds a dict-based sparse row
    reduction, a nullspace routine that uses it, and a conversion to a
    numpy array.
    """

    def rref_sparse(self, simplify=False):
        """ A sparse row reduce algorithm mimicking the dense version.

        Each non-empty row is handled as a ``{column: value}`` dict, so only
        stored entries are touched. Rows that are (or become) empty are
        dropped from the working list, which packs the remaining rows at the
        top of the returned matrix.

        Parameters
        ----------
        simplify
            accepted so that the call signature matches the dense ``rref``;
            this implementation does not use it

        Returns
        -------
        tuple
            the row reduced matrix (same shape as this matrix) and a tuple
            with the column index of each pivot, ordered by pivot row
        """

        def add_scaled_row(target, factor, source):
            """ Does target += factor * source when rows are represented by
            dicts. Entries that become zero are removed from target.
            """
            vanished = []
            for col, value in source.items():
                target[col] += factor * value
                if target[col] == 0:
                    vanished.append(col)
            # Deletion is deferred until all of source has been applied
            for col in vanished:
                del target[col]

        # The matrix is represented as a list of dicts where each dict is a
        # row. Missing entries read as 0 thanks to defaultdict, which is what
        # makes the in-place += in add_scaled_row possible.
        rows_by_index = defaultdict(lambda: defaultdict(int))
        for i, j, value in self.row_list():
            rows_by_index[i][j] = value
        rows = list(rows_by_index.values())

        # The pivot elements are stored as (row, col) tuples
        pivot_cells = []
        i_row, i_col = 0, 0  # current row and col of possible pivot
        n_rows, n_cols = len(rows), self.shape[1]
        while i_row < n_rows and i_col < n_cols:
            pivot_row = rows[i_row]
            # Check if the proposed pivot is zero. If so, swap this row with
            # the first row below that has a non-zero element in that column.
            if i_col not in pivot_row:
                swap_with = next(
                    (k for k in range(i_row + 1, n_rows) if i_col in rows[k]),
                    None)
                if swap_with is None:
                    # The pivot and all elements below are zero:
                    # go to the next column but stay on this row
                    i_col += 1
                    continue
                rows[i_row], rows[swap_with] = rows[swap_with], pivot_row
                pivot_row = rows[i_row]
            pivot_cells.append((i_row, i_col))

            # Normalize the pivot row so that the pivot element becomes one
            lead = pivot_row[i_col]
            for col in pivot_row:
                pivot_row[col] /= lead

            # Eliminate the pivot column from all rows below
            below = i_row + 1
            while below < n_rows:
                other = rows[below]
                if i_col in other:
                    add_scaled_row(other, -other[i_col], pivot_row)
                    if not other:
                        # The row vanished; the next row slides into this
                        # index, so the index must not be advanced
                        del rows[below]
                        n_rows -= 1
                        continue
                below += 1
            i_row += 1
            i_col += 1

        # Eliminate elements above pivots
        for i_row, i_col in pivot_cells:
            pivot_row = rows[i_row]
            for k in range(i_row):
                above = rows[k]
                if i_col in above:
                    add_scaled_row(above, -above[i_col], pivot_row)

        # Create the new row reduced matrix
        reduced = SparseMatrix(*self.shape, 0)
        for i, entries in enumerate(rows):
            for j, value in entries.items():
                reduced[i, j] = value

        return reduced, tuple(col for _, col in pivot_cells)

    def nullspace(self, simplify=False):
        """ This is a slightly patched version which also uses the sparse rref
        and is faster due to up-front creation of empty SparseMatrix
        vectors instead of conversion of the finished vectors.

        Parameters
        ----------
        simplify
            forwarded to the row reduction; only the dense ``rref`` (used
            when the largest dimension is below 10) makes use of it

        Returns
        -------
        list
            one vector with ``self.cols`` rows and a single column per free
            variable
        """
        if max(self.shape) < 10:  # If matrix small use the dense version
            reduced, pivots = self.rref(simplify=simplify)
        else:
            reduced, pivots = self.rref_sparse(simplify=simplify)

        # Set lookup avoids rescanning the pivot tuple for every column
        pivot_cols = set(pivots)
        free_vars = [i for i in range(self.cols) if i not in pivot_cols]

        basis = []
        for free_var in free_vars:
            # for each free variable, we will set it to 1 and all others
            # to 0. Then, we will use back substitution to solve the system

            # initialize each vector as an empty SparseMatrix
            vec = self._new(self.cols, 1, 0)
            vec[free_var] = 1
            # the k-th pivot sits in row k of the reduced matrix
            for piv_row, piv_col in enumerate(pivots):
                vec[piv_col] -= reduced[piv_row, free_var]
            basis.append(vec)

        return basis

    def to_array(self):
        """ Cast SparseMatrix instance to numpy array.

        Every stored entry is converted with ``np.float64``; the dense array
        is produced via a ``coo_matrix`` of the same shape.
        """
        entries = self.row_list()
        row = [r for r, _, _ in entries]
        col = [c for _, c, _ in entries]
        data = [np.float64(v) for _, _, v in entries]
        return coo_matrix((data, (row, col)), shape=self.shape).toarray()

135 

136 

class BiMap:
    """List-like container backed by a dict for fast reverse lookup.

    Objects are appended to an internal list, while an internal dict records
    the position at which each object was stored. This gives list-style
    access by position together with dict-speed membership tests and
    ``index`` queries.
    """

    def __init__(self):
        self._list = []
        self._dict = {}

    def __len__(self):
        return len(self._list)

    def __contains__(self, value):
        return value in self._dict

    def __getitem__(self, index):
        return self._list[index]

    def append(self, value):
        """bm.append(hashable) -> None -- append hashable object to end"""
        position = len(self._list)
        # The dict is updated first so that an unhashable value is rejected
        # before the list has been modified
        self._dict[value] = position
        self._list.append(value)

    def index(self, value) -> int:
        """Returns the position recorded for value.
        Raises ValueError if the value is not present.
        """
        try:
            position = self._dict[value]
        except KeyError:
            raise ValueError('{} is not in list'.format(value))
        return position

    def copy(self):
        """ A shallow copy of the BiMap. """
        duplicate = BiMap()
        duplicate._list = list(self._list)
        duplicate._dict = dict(self._dict)
        return duplicate

175 tbm._dict = self._dict.copy() 

176 return tbm