Coverage for hiphive/utilities.py: 97%

121 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-01 17:04 +0000

1""" 

2This module contains various support/utility functions. 

3""" 

4 

5from typing import List, Tuple 

6import numpy as np 

7 

8from ase import Atoms 

9from ase.calculators.singlepoint import SinglePointCalculator 

10from ase.geometry import find_mic 

11from ase.geometry import get_distances 

12from ase.neighborlist import neighbor_list 

13from .cluster_space import ClusterSpace 

14from .force_constants import ForceConstants 

15from .input_output.logging_tools import logger 

16 

17 

18logger = logger.getChild('utilities') 

19 

20 

def get_displacements(atoms: Atoms,
                      atoms_ideal: Atoms,
                      cell_tol: float = 1e-4) -> np.ndarray:
    """Returns the smallest possible displacements of a displaced
    configuration relative to an ideal (reference) configuration.

    Notes
    -----
    * uses :func:`ase.geometry.find_mic`
    * assumes periodic boundary conditions in all directions

    Parameters
    ----------
    atoms
        Configuration with displaced atoms.
    atoms_ideal
        Ideal configuration relative to which displacements are computed.
    cell_tol
        Cell tolerance; if the norm of the difference between the two cells
        exceeds this value an error is raised.

    Returns
    -------
    Position differences between the two configurations after wrapping
    them with the minimum image convention.

    Raises
    ------
    ValueError
        If the atomic numbers of the two configurations are not equal or
        if the cells differ by more than :attr:`cell_tol`.
    """
    # both configurations must list the same species in the same order
    same_species = np.array_equal(atoms.numbers, atoms_ideal.numbers)
    if not same_species:
        raise ValueError('Atomic numbers do not match.')

    cell_mismatch = np.linalg.norm(atoms.cell - atoms_ideal.cell)
    if cell_mismatch > cell_tol:
        raise ValueError('Cells do not match.')

    # only the first element of the find_mic result (the wrapped vectors) is used
    mic_result = find_mic(atoms.positions - atoms_ideal.positions, atoms_ideal.cell, pbc=True)
    return mic_result[0]

50 

51 

def _get_forces_from_atoms(atoms: Atoms, calc=None) -> np.ndarray:
    """Try to get forces from an atoms object.

    The forces are looked up in the following order:

    1. computed with the calculator passed via :attr:`calc` (on a copy of
       :attr:`atoms`),
    2. read from the calculator attached to :attr:`atoms`, which must be a
       :class:`SinglePointCalculator`,
    3. read from the ``'forces'`` array of :attr:`atoms`.

    Whenever forces are obtained from a calculator and a ``'forces'`` array is
    also present, the two are required to agree (:func:`numpy.allclose`).

    Parameters
    ----------
    atoms
        Structure for which to obtain forces.
    calc
        Optional calculator used to compute the forces.

    Raises
    ------
    ValueError
        If a calculator is both attached and provided, if the attached
        calculator is not a :class:`SinglePointCalculator`, if calculator and
        array forces disagree, or if no forces can be found at all.
    """

    def _consistent_with_arrays(calculator_forces):
        # calculator forces must agree with any forces stored as an array
        if 'forces' in atoms.arrays:
            if not np.allclose(calculator_forces, atoms.get_array('forces')):
                raise ValueError('Forces in atoms.arrays are different from the calculator forces.')
        return calculator_forces

    attached_calc = atoms.calc

    # two calculators at once would be ambiguous
    if attached_calc is not None and calc is not None:
        raise ValueError('Atoms.calc is not None and calculator was provided.')

    # calculator handed in as argument: evaluate on a copy, leaving the input without calculator
    if calc is not None:
        scratch = atoms.copy()
        scratch.calc = calc
        return _consistent_with_arrays(scratch.get_forces())

    # calculator attached to the atoms object
    if attached_calc is not None:
        if not isinstance(attached_calc, SinglePointCalculator):
            raise ValueError('atoms.calc is not a SinglePointCalculator.')
        return _consistent_with_arrays(atoms.get_forces())

    # no calculator anywhere, forces therefore have to be stored in atoms.arrays
    if 'forces' not in atoms.arrays:
        raise ValueError('Unable to find forces.')
    return atoms.get_array('forces')

86 

87 

def prepare_structure(atoms: Atoms,
                      atoms_ideal: Atoms,
                      calc: SinglePointCalculator = None,
                      check_permutation: bool = True) -> Atoms:
    """Prepare a structure in the format suitable for a
    :class:`StructureContainer <hiphive.StructureContainer>`.

    Forces must be available in one of three ways: as a ``'forces'`` array on
    the input atoms object, via a SinglePointCalculator attached to the input
    atoms object, or via a calculator supplied through :attr:`calc`.

    Parameters
    ----------
    atoms
        Input structure.
    atoms_ideal
        Reference structure relative to which displacements are computed.
    calc
        ASE calculator to use for computing forces.
    check_permutation
        Whether :func:`find_permutation` should be used or not. If not, the
        atoms are assumed to already be in the same order as in
        :attr:`atoms_ideal`.

    Returns
    -------
    Prepared ASE :class:`Atoms` object with forces and displacements as arrays
    and with positions set to those of :attr:`atoms_ideal`.
    """
    forces = _get_forces_from_atoms(atoms, calc=calc)

    # ordering that maps the input atoms onto the sites of the reference structure
    if check_permutation:
        order = find_permutation(atoms, atoms_ideal)
    else:
        order = np.array(list(range(len(atoms))))

    prepared = atoms.copy()[order]
    prepared.arrays['forces'] = forces[order]
    # displacements have to be evaluated before the positions are reset below
    prepared.arrays['displacements'] = get_displacements(prepared, atoms_ideal)
    prepared.positions = atoms_ideal.positions

    return prepared

130 

131 

def prepare_structures(structures: List[Atoms],
                       atoms_ideal: Atoms,
                       calc: SinglePointCalculator = None,
                       check_permutation: bool = True) -> List[Atoms]:
    """Prepares a set of structures in the format suitable for adding them to
    a :class:`StructureContainer <hiphive.StructureContainer>`.

    :attr:`structures` should be a list of supercells with displacements
    while :attr:`atoms_ideal` should provide the ideal reference structure
    (without displacements) for the given structures. Each structure is
    processed individually by :func:`prepare_structure`.

    The structures that are returned will have their positions reset to the
    ideal structure. Displacements and forces will be added as arrays to the
    atoms objects.

    If no calculator is provided, then there must be an ASE
    :class:`SinglePointCalculator <ase.calculators.singlepoint>` object attached to
    the structures or the forces should already be attached as
    arrays to the structures.

    If a calculator is provided then it will be used to compute the forces for
    all structures.

    Example
    -------

    The following example illustrates the use of this function::

        db = connect('dft_training_structures.db')
        training_structures = [row.toatoms() for row in db.select()]
        training_structures = prepare_structures(training_structures, atoms_ideal)
        for s in training_structures:
            sc.add_structure(s)

    Parameters
    ----------
    structures
        List of input displaced structures.
    atoms_ideal
        Reference structure relative to which displacements are computed.
    calc
        ASE calculator to use for computing forces.
    check_permutation
        Whether :func:`find_permutation` should be used or not.

    Returns
    -------
    List of prepared structures with forces and displacements as arrays.
    """
    return [prepare_structure(atoms=structure,
                              atoms_ideal=atoms_ideal,
                              calc=calc,
                              check_permutation=check_permutation)
            for structure in structures]

180 

181 

def find_permutation(atoms: Atoms, atoms_ref: Atoms) -> List[int]:
    """Returns the best permutation of atoms for mapping one
    configuration onto another.

    For every site in :attr:`atoms_ref` the closest atom in :attr:`atoms`
    (using periodic boundary conditions) is selected.

    Parameters
    ----------
    atoms
        configuration to be permuted
    atoms_ref
        configuration onto which to map

    Raises
    ------
    AssertionError
        If the cells of the two configurations differ. Note that this check
        is an ``assert`` statement and is therefore skipped when Python runs
        with ``-O``.
    Exception
        If the same atom is matched to several reference sites or if a
        matched pair of sites is occupied by different species.

    Examples
    --------
    After obtaining the permutation via ``p = find_permutation(atoms1, atoms2)``
    the reordered structure ``atoms1[p]`` will give the closest match
    to ``atoms2``.
    """
    # NOTE(review): the number of atoms is not compared; if `atoms` holds more
    # atoms than `atoms_ref` the surplus atoms are left out of the result.
    assert np.linalg.norm(atoms.cell - atoms_ref.cell) < 1e-6

    def _closest_atom(ref_position):
        # element [1] of the get_distances result holds the distance matrix;
        # its single column refers to the one reference position passed in
        lengths = get_distances(
            atoms.positions, ref_position, cell=atoms_ref.cell, pbc=True)[1]
        return np.argmin(lengths[:, 0])

    permutation = [_closest_atom(atoms_ref.positions[site]) for site in range(len(atoms_ref))]

    # each atom may only be assigned to a single reference site
    if len(set(permutation)) != len(permutation):
        raise Exception('Duplicates in permutation')

    # matched sites must carry the same chemical species
    symbols = atoms.get_chemical_symbols()
    ref_symbols = atoms_ref.get_chemical_symbols()
    for ref_index, atom_index in enumerate(permutation):
        if symbols[atom_index] != ref_symbols[ref_index]:
            raise Exception('Matching lattice sites have different occupation')
    return permutation

212 

213 

class Shell:
    """
    Container describing a neighbor shell.

    Parameters
    ----------
    types : Union[list, tuple]
        Atomic types for neighbor shell.
    distance : float
        Interatomic distance for neighbor shell.
    count : int
        Number of pairs in the neighbor shell.
    """

    def __init__(self,
                 types: List[str],
                 distance: float,
                 count: int = 0):
        self.types, self.distance, self.count = types, distance, count

    def __str__(self):
        # the template has two type slots, followed by distance and count
        fields = (*self.types, self.distance, self.count)
        return '{}-{} distance: {:10.6f} count: {}'.format(*fields)

    # the printable representation doubles as the repr
    __repr__ = __str__

241 

242 

def get_neighbor_shells(atoms: Atoms,
                        cutoff: float,
                        dist_tol: float = 1e-5) -> List[Shell]:
    """Returns a list of neighbor shells.

    Distances are grouped into shells via the following algorithm:

    1. Collect all pairs within the cutoff and sort them by distance.

    2. Add each pair to the first existing shell that has the same pair of
       atomic types and whose distance differs by less than `dist_tol`.

    3. If no such shell exists, open a new shell with the distance of
       this pair.

    The resulting shells are sorted by distance, types, and count. A warning
    is logged for every two shells of the same types that are not separated
    by more than `2 * dist_tol`.

    Parameters
    ----------
    atoms
        Configuration used for finding shells.
    cutoff
        Exclude neighbor shells which have a distance larger than this value.
    dist_tol
        Distance tolerance.
    """

    # (i, j, distance) triplets, shortest distance first
    pairs = sorted(zip(*neighbor_list('ijd', atoms, cutoff)), key=lambda entry: entry[2])

    # sort into shells
    symbols = atoms.get_chemical_symbols()
    shells = []
    for first, second, dist in pairs:
        pair_types = tuple(sorted((symbols[first], symbols[second])))
        match = next((candidate for candidate in shells
                      if abs(dist - candidate.distance) < dist_tol
                      and pair_types == candidate.types), None)
        if match is None:
            shells.append(Shell(pair_types, dist, 1))
        else:
            match.count += 1
    shells.sort(key=lambda shell: (shell.distance, shell.types, shell.count))

    # warning if two shells of the same types are within 2 * tol
    for index, inner in enumerate(shells):
        for outer in shells[index + 1:]:
            if inner.types == outer.types and not inner.distance < outer.distance - 2 * dist_tol:
                logger.warning('Found two shells within 2 * dist_tol')

    return shells

296 

297 

def extract_parameters(fcs: ForceConstants,
                       cs: ClusterSpace,
                       sanity_check: bool = True,
                       lstsq_method: str = 'numpy') -> np.ndarray:
    """Extracts parameters from force constants.

    This function can be used to extract parameters to create a
    ForceConstantPotential from a known set of force constants.
    The parameters are the least-squares solution obtained from NumPy's `lstsq function
    <https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html>`_
    or from SciPy's `sparse lsqr function
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html>`_;
    only the solution vector is returned, not the additional solver output.
    Using `lstsq_method='scipy'` might be faster and have a smaller memory footprint for large
    systems, at the expense of some accuracy. This is due to the use of sparse matrices
    and an iterative solver.

    Parameters
    ----------
    fcs
        Force constants.
    cs
        Cluster space.
    sanity_check
        Bool whether or not to perform a sanity check by computing the relative error between
        the input fcs and the output FCs. The error is printed for each order of the
        cluster space.
    lstsq_method
        Method to use when making a least-squares fit of a :class:`ForceConstantModel` to the
        given FCs, allowed values are `'numpy'` for :func:`np.linalg.lstsq`
        and `'scipy'` for :func:`scipy.sparse.linalg.lsqr`.

    Returns
    -------
    Parameters that together with the ClusterSpace give the best representation of the FCs.

    Raises
    ------
    ValueError
        If :attr:`lstsq_method` is neither `'numpy'` nor `'scipy'`.
    """
    # validate the input before doing any work
    if lstsq_method not in ('numpy', 'scipy'):
        raise ValueError('lstsq_method must be either numpy or scipy')

    from .force_constant_model import ForceConstantModel
    from .force_constant_potential import ForceConstantPotential

    # extract the parameters
    fcm = ForceConstantModel(fcs.supercell, cs)
    if lstsq_method == 'numpy':
        A, b = fcm.get_fcs_sensing(fcs, sparse=False)
        parameters = np.linalg.lstsq(A, b, rcond=None)[0]
    else:
        # sparse iterative solver, intended for large cluster spaces
        from scipy.sparse.linalg import lsqr
        A, b = fcm.get_fcs_sensing(fcs, sparse=True)
        # set minimal tolerances to maximize iterative least squares accuracy
        parameters = lsqr(A, b, atol=0, btol=0, conlim=0)[0]

    # calculate the relative force constant error
    if sanity_check:
        fcp = ForceConstantPotential(cs, parameters)
        fcs_hiphive = fcp.get_force_constants(fcs.supercell)
        for order in cs.cutoffs.orders:
            fc_original = fcs.get_fc_array(order=order)
            fc_reconstructed = fcs_hiphive.get_fc_array(order=order)
            rel_error = np.linalg.norm(fc_original-fc_reconstructed) / np.linalg.norm(fc_original)
            print(f'Force constant reconstruction error order {order}: {100*rel_error:9.4f}%')

    return parameters