Coverage for hiphive/core/utilities.py: 99%
112 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1"""
2The ``utilities`` module contains various support functions and classes.
3"""
4import sympy
5import numpy as np
6from scipy.sparse import coo_matrix
7from collections import defaultdict
8from ..input_output.logging_tools import Progress
9from ase import Atoms
10from typing import Tuple
# Public API of this module: only ``Progress`` is re-exported here (it is
# imported from ..input_output.logging_tools at the top of the module).
__all__ = ['Progress']
def ase_atoms_to_spglib_tuple(structure: Atoms) -> Tuple:
    """
    Convert an ASE Atoms object into a ``(cell, positions, numbers)`` tuple.

    Parameters
    ----------
    structure
        input structure

    Returns
    -------
    tuple
        ``structure.cell``, the scaled (fractional) atomic positions and the
        atomic numbers, in that order
    """
    cell = structure.cell
    fractional_positions = structure.get_scaled_positions()
    atomic_numbers = structure.get_atomic_numbers()
    return cell, fractional_positions, atomic_numbers
class SparseMatrix(sympy.SparseMatrix):
    """
    Subclass of :class:`sympy.SparseMatrix` that adds a sparse row-reduction
    routine, a nullspace computation built on top of it, and a conversion
    to a dense numpy array.
    """

    def rref_sparse(self, simplify=False):
        """
        Compute the reduced row echelon form with a sparse algorithm that
        mimics the dense :meth:`rref`.

        Each row is held as a dict mapping column index to nonzero value, so
        only nonzero entries are visited during elimination.

        Parameters
        ----------
        simplify
            accepted for signature compatibility with :meth:`rref`; it is not
            used by this implementation

        Returns
        -------
        tuple
            the reduced matrix (same shape as ``self``) and a tuple with the
            pivot column index of each nonzero row, in row order
        """

        def multiply_and_add_row(row1, factor, row2):
            """Do ``row1 += factor * row2`` in place for dict-represented
            rows, dropping entries of ``row1`` that become exactly zero."""
            keys_to_delete = []
            for k, v in row2.items():
                row1[k] += factor * v
                if row1[k] == 0:
                    keys_to_delete.append(k)
            for k in keys_to_delete:
                del row1[k]

        # The matrix is represented as a list of dicts where each dict is a
        # row. The inner defaultdict lets multiply_and_add_row add into
        # columns that are not yet present. Rows without any nonzero entry
        # never appear in M.
        M = defaultdict(lambda: defaultdict(lambda: 0))
        for r, c, v in self.row_list():
            M[r][c] = v
        M = list(M.values())

        # Pivot elements are stored as (row, col) tuples
        pivots = []
        r, c = 0, 0  # current row and col of the candidate pivot
        nRows, nCols = len(M), self.shape[1]
        while r < nRows and c < nCols:
            row = M[r]
            # If the candidate pivot is zero, swap this row with a row below
            # that has a nonzero element in this column.
            if c not in row:
                for r2 in range(r + 1, nRows):
                    row2 = M[r2]
                    if c in row2:  # this row has an element in column c
                        M[r], M[r2] = row2, row  # swap the rows
                        row = row2
                        break
                else:  # the pivot and all elements below it are zero
                    c += 1  # go to the next column but stay on this row
                    continue
            pivots.append((r, c))
            # Normalize the pivot row so that the pivot becomes 1
            row_c = row[c]
            for k in row:
                row[k] /= row_c
            # Eliminate column c from all rows below. Rows that become
            # entirely zero are removed, which keeps the nonzero rows on top
            # and leaves the indices of already recorded pivots unchanged.
            r2 = r + 1
            while r2 < nRows:
                row2 = M[r2]
                if c not in row2:
                    r2 += 1
                    continue
                multiply_and_add_row(row2, -row2[c], row)
                if len(row2) == 0:
                    nRows -= 1
                    del M[r2]
                    continue
                r2 += 1
            r += 1
            c += 1

        # Eliminate elements above each pivot
        for (r, p) in pivots:
            row_p = M[r]
            for i in range(r):
                row = M[i]
                if p in row:
                    multiply_and_add_row(row, -row[p], row_p)

        # Build the reduced matrix in one construction from a
        # {(i, j): value} dict instead of assigning element by element.
        entries = {(i, j): v for i, d in enumerate(M) for j, v in d.items()}
        M2 = SparseMatrix(*self.shape, entries)

        pivots = tuple(p[1] for p in pivots)

        return M2, pivots

    def nullspace(self, simplify=False):
        """
        Return a basis of the nullspace as a list of column vectors.

        This is a slightly patched version of the sympy implementation. It
        uses :meth:`rref_sparse` for larger matrices. It is also faster
        because each basis vector is created up front as an empty sparse
        matrix instead of converting the finished vectors.

        Parameters
        ----------
        simplify
            forwarded to the row-reduction routine

        Returns
        -------
        list
            one ``(cols, 1)`` vector per free (non-pivot) column
        """
        if max(*self.shape) < 10:  # if the matrix is small use the dense version
            reduced, pivots = self.rref(simplify=simplify)
        else:
            reduced, pivots = self.rref_sparse(simplify=simplify)

        # A set makes each membership test O(1). Testing against the pivots
        # tuple made this scan O(cols * pivots).
        pivot_set = set(pivots)
        free_vars = [i for i in range(self.cols) if i not in pivot_set]

        basis = []
        for free_var in free_vars:
            # For each free variable, set it to 1 and all other free
            # variables to 0, then back substitute for the pivot variables.
            # Each vector is initialized as an empty sparse column.
            vec = self._new(self.cols, 1, 0)
            vec[free_var] = 1
            for piv_row, piv_col in enumerate(pivots):
                vec[piv_col] -= reduced[piv_row, free_var]
            basis.append(vec)

        return basis

    def to_array(self):
        """Cast the SparseMatrix instance to a dense numpy array whose
        entries are converted with ``np.float64``."""
        row, col, data = [], [], []
        for r, c, v in self.row_list():
            row.append(r)
            col.append(c)
            data.append(np.float64(v))
        M = coo_matrix((data, (row, col)), shape=self.shape)
        return M.toarray()
class BiMap:
    """Simple list-like structure with fast dict-lookup.

    The structure can append objects and supports some simple list
    interfaces. Lookup is fast because an internal dict maps each stored
    value to its index, so stored values must be hashable.
    """

    def __init__(self):
        self._list = list()
        self._dict = dict()

    def __contains__(self, value):
        return value in self._dict

    def __getitem__(self, index):
        return self._list[index]

    def __iter__(self):
        # Iterate the backing list directly instead of relying on the
        # __getitem__-based sequence protocol; yields the same items.
        return iter(self._list)

    def append(self, value):
        """bm.append(hashable) -> None -- append hashable object to end

        If ``value`` is already present, the index of its first occurrence
        is kept, so that :meth:`index` behaves like ``list.index``.
        """
        self._dict.setdefault(value, len(self._list))
        self._list.append(value)

    def __len__(self):
        return len(self._list)

    def index(self, value):
        """bm.index(value) -> integer -- return index of first occurrence of value

        Raises ValueError if the value is not present.
        """
        try:
            return self._dict[value]
        except KeyError:
            # Hide the internal KeyError; callers see only a list-style error
            raise ValueError('{} is not in list'.format(value)) from None

    def copy(self):
        """bm.copy() -> BiMap -- a shallow copy of bm"""
        tbm = BiMap()
        tbm._list = self._list.copy()
        tbm._dict = self._dict.copy()
        return tbm