Coverage for hiphive/core/utilities.py: 99%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

"""
The ``utilities`` module contains various support functions and classes:
``ase_atoms_to_spglib_tuple`` (ASE ``Atoms`` -> ``(cell, positions, numbers)``
tuple), ``SparseMatrix`` (a ``sympy.SparseMatrix`` subclass with a sparse
row reduction, nullspace and numpy conversion) and ``BiMap`` (a list-like
container with dict-based index lookup).
"""

4import sympy 

5import numpy as np 

6from scipy.sparse import coo_matrix 

7from collections import defaultdict 

8from ..input_output.logging_tools import Progress 

9from ase import Atoms 

10from typing import Tuple 

11 

# Only ``Progress`` (imported above from ``..input_output.logging_tools``) is
# listed, so ``from ... import *`` on this module exposes nothing else; the
# other names defined here must be imported explicitly.
__all__ = ['Progress']

13 

14 

def ase_atoms_to_spglib_tuple(structure: Atoms) -> Tuple:
    """
    Converts an ASE Atoms object into a ``(cell, positions, numbers)`` tuple
    as used for spglib input.

    Parameters
    ----------
    structure
        input structure

    Returns
    -------
    tuple
        the cell of the structure (``structure.cell``, i.e. the lattice
        vectors), the scaled (fractional) atomic positions, and the atomic
        numbers
    """
    cell = structure.cell
    scaled_positions = structure.get_scaled_positions()
    atomic_numbers = structure.get_atomic_numbers()
    return cell, scaled_positions, atomic_numbers

21 

22 

class SparseMatrix(sympy.SparseMatrix):
    """
    ``sympy.SparseMatrix`` extended with a sparse row reduction
    (:meth:`rref_sparse`), a :meth:`nullspace` that uses it for larger
    matrices, and :meth:`to_array` for conversion to a dense numpy array.
    """

    def rref_sparse(self, simplify=False):
        """ A sparse row reduce algorithm mimicking the dense version.

        Parameters
        ----------
        simplify
            accepted for signature compatibility with ``rref``; it is not
            used by this algorithm

        Returns
        -------
        tuple
            ``(M2, pivots)`` where ``M2`` is a new :class:`SparseMatrix` with
            the same shape holding the reduced row echelon form, and
            ``pivots`` is a tuple with the column index of each pivot in
            row order
        """

        def multiply_and_add_row(row1, factor, row2):
            """ Does row1 += factor * row2 in place when rows are represented
            by dicts. Entries of row1 that become exactly zero are removed so
            that ``col in row`` keeps meaning "non-zero entry at col".
            """
            keys_to_delete = []
            for k, v in row2.items():
                row1[k] += factor * v
                if row1[k] == 0:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del row1[k]

        # The matrix is represented as a list of dicts where each dict is a
        # row mapping column index -> non-zero value. Rows without any entry
        # in row_list() are never created, i.e. zero rows are dropped.
        M = defaultdict(lambda: defaultdict(lambda: 0))

        # Init our special representation. This is possible due to defaultdict
        for r, c, v in self.row_list():
            M[r][c] = v
        M = list(M.values())

        # The pivot elements are stored as (row, col) tuples
        pivots = []
        r, c = 0, 0  # current row and col of possible pivot
        nRows, nCols = len(M), self.shape[1]
        while r < nRows and c < nCols:
            row = M[r]
            # Check if the proposed pivot is zero. If so, swap this row with a
            # row below which has a non-zero element in that column.
            if c not in row:
                for r2 in range(r + 1, nRows):
                    row2 = M[r2]
                    if c in row2:  # This row has an element in the current col
                        M[r], M[r2] = row2, row  # swap the rows
                        row = row2
                        break
                else:  # The pivot and all elements below are zero.
                    c += 1  # go to next column but stay on row
                    continue
            pivots.append((r, c))
            # Normalize the pivot row so the pivot entry becomes 1
            row_c = row[c]
            for k in row:
                row[k] /= row_c
            # Eliminate the pivot column from all rows below
            r2 = r + 1
            while r2 < nRows:
                row2 = M[r2]
                if c not in row2:
                    r2 += 1
                    continue
                multiply_and_add_row(row2, -row2[c], row)
                if len(row2) == 0:
                    # The row vanished: drop it. The next row shifts into
                    # index r2, so r2 is deliberately not advanced.
                    nRows -= 1
                    del M[r2]
                    continue
                r2 += 1
            r += 1
            c += 1

        # Eliminate elements above pivots
        for (r, p) in pivots:
            row_p = M[r]
            for i in range(r):
                row = M[i]
                if p in row:
                    multiply_and_add_row(row, -row[p], row_p)

        # Create the new rref'd matrix; rows beyond len(M) (the dropped zero
        # rows) remain zero at the bottom
        M2 = SparseMatrix(*self.shape, 0)
        for i, d in enumerate(M):
            for j, v in d.items():
                M2[i, j] = v

        pivots = tuple(p[1] for p in pivots)

        return M2, pivots

    def nullspace(self, simplify=False):
        """ Returns a basis of the nullspace as a list of column vectors.

        This is a slightly patched version which also uses the sparse rref
        (for matrices whose larger dimension is at least 10) and is faster
        due to up-front creation of empty SparseMatrix vectors instead of
        conversion of the finished vectors.

        Parameters
        ----------
        simplify
            forwarded to the row reduction (ignored by ``rref_sparse``)
        """
        if (max(*self.shape) < 10):  # If matrix small use the dense version
            reduced, pivots = self.rref(simplify=simplify)
        else:
            reduced, pivots = self.rref_sparse(simplify=simplify)

        # Set membership keeps this O(cols) instead of O(cols * rank) when
        # scanning the pivots tuple for every column.
        pivot_set = set(pivots)
        free_vars = [i for i in range(self.cols) if i not in pivot_set]

        basis = []
        for free_var in free_vars:
            # for each free variable, we will set it to 1 and all others
            # to 0. Then, we will use back substitution to solve the system

            # initialize each vector as an empty (cols x 1) SparseMatrix
            vec = self._new(self.cols, 1, 0)
            vec[free_var] = 1
            for piv_row, piv_col in enumerate(pivots):
                vec[piv_col] -= reduced[piv_row, free_var]
            basis.append(vec)

        return basis

    def to_array(self):
        """ Cast SparseMatrix instance to a dense numpy float64 array """
        row, col, data = [], [], []
        for r, c, v in self.row_list():
            row.append(r)
            col.append(c)
            data.append(np.float64(v))
        M = coo_matrix((data, (row, col)), shape=self.shape)
        return M.toarray()

139 

140 

class BiMap:
    """Simple list like structure with fast dict-lookup

    The structure can append objects and supports some simple list interfaces
    (``len``, ``in``, indexing and ``index``). The lookup is fast since an
    internal dict maps each stored value to its index, so values must be
    hashable.

    Note: appending a value that is already present adds a second list entry
    and makes :meth:`index` return the position of the most recent
    occurrence (unlike ``list.index``, which returns the first).
    """
    def __init__(self):
        self._list = list()  # values in insertion order
        self._dict = dict()  # value -> index into self._list

    def __contains__(self, value):
        """O(1) membership test via the internal dict."""
        return value in self._dict

    def __getitem__(self, index):
        """Return the item(s) at ``index`` (int or slice) of the list."""
        return self._list[index]

    def append(self, value):
        """bm.append(hashable) -> None -- append hashable object to end"""
        self._dict[value] = len(self._list)
        self._list.append(value)

    def __len__(self):
        return len(self._list)

    def index(self, value):
        """bm.index(value) -> integer -- return index of value
        Raises ValueError if the value is not present.
        """
        try:
            return self._dict[value]
        except KeyError:
            # Hide the internal KeyError so callers get a clean,
            # list-style ValueError traceback.
            raise ValueError('{} is not in list'.format(value)) from None

    def copy(self):
        """bm.copy() -> BiMap -- a shallow copy of bm"""
        tbm = BiMap()
        tbm._list = self._list.copy()
        tbm._dict = self._dict.copy()
        return tbm