Coverage for hiphive/utilities.py: 97%

121 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

1""" 

2This module contains various support/utility functions. 

3""" 

4 

5from typing import List, Tuple 

6import numpy as np 

7 

8from ase import Atoms 

9from ase.calculators.singlepoint import SinglePointCalculator 

10from ase.geometry import find_mic 

11from ase.geometry import get_distances 

12from ase.neighborlist import neighbor_list 

13from .cluster_space import ClusterSpace 

14from .force_constants import ForceConstants 

15from .input_output.logging_tools import logger 

16 

17 

# Child of the logger imported from input_output.logging_tools; used for
# warnings emitted by this module (e.g. in get_neighbor_shells).
logger = logger.getChild('utilities')

19 

20 

def get_displacements(atoms: Atoms,
                      atoms_ideal: Atoms,
                      cell_tol: float = 1e-4) -> np.ndarray:
    """Returns the smallest possible displacements of a displaced
    configuration relative to an ideal (reference) configuration.

    Notes
    -----
    * uses :func:`ase.geometry.find_mic`
    * assumes periodic boundary conditions in all directions

    Parameters
    ----------
    atoms
        configuration with displaced atoms
    atoms_ideal
        ideal configuration relative to which displacements are computed
    cell_tol
        cell tolerance; if the norm of the difference between the two cells
        exceeds this value an error is raised

    Returns
    -------
    displacement vectors (``atoms.positions - atoms_ideal.positions``)
    reduced to their minimum-image form with respect to the cell of
    ``atoms_ideal``

    Raises
    ------
    ValueError
        if the atomic numbers of the two configurations differ or if their
        cells differ by more than ``cell_tol``
    """
    same_species = np.array_equal(atoms.numbers, atoms_ideal.numbers)
    if not same_species:
        raise ValueError('Atomic numbers do not match.')

    cell_mismatch = np.linalg.norm(atoms.cell - atoms_ideal.cell)
    if cell_mismatch > cell_tol:
        raise ValueError('Cells do not match.')

    # find_mic returns the reduced vectors first; their lengths are not needed
    mic_vectors = find_mic(atoms.positions - atoms_ideal.positions,
                           atoms_ideal.cell, pbc=True)[0]
    return mic_vectors

50 

51 

def _get_forces_from_atoms(atoms: Atoms, calc=None) -> np.ndarray:
    """Returns the forces associated with a structure.

    Forces are taken from the calculator ``calc`` passed as argument if it
    is given. Otherwise they come from a SinglePointCalculator attached to
    ``atoms``, or else from the ``'forces'`` entry of ``atoms.arrays``. When
    forces come from a calculator and ``atoms.arrays`` also holds forces,
    the two sets must agree according to :func:`numpy.allclose`.

    Parameters
    ----------
    atoms
        structure for which forces are retrieved
    calc
        optional ASE calculator; must not be given if ``atoms`` already has
        a calculator attached

    Raises
    ------
    ValueError
        if both ``calc`` and ``atoms.calc`` are set, if the attached
        calculator is not a SinglePointCalculator, if calculator forces
        disagree with ``atoms.arrays['forces']``, or if no forces are found
    """

    def _validated(calculator_forces):
        # forces stored in atoms.arrays, if present, must match the calculator
        if 'forces' in atoms.arrays and \
                not np.allclose(calculator_forces, atoms.get_array('forces')):
            raise ValueError('Forces in atoms.arrays are different from the calculator forces')
        return calculator_forces

    if calc is not None:
        if atoms.calc is not None:
            raise ValueError('Atoms.calc is not None and calculator was provided')
        # evaluate on a copy so the caller's atoms object is left untouched
        work_copy = atoms.copy()
        work_copy.calc = calc
        return _validated(work_copy.get_forces())

    if atoms.calc is not None:
        if not isinstance(atoms.calc, SinglePointCalculator):
            raise ValueError('atoms.calc is not a SinglePointCalculator')
        return _validated(atoms.get_forces())

    # neither a calculator argument nor an attached calculator
    if 'forces' not in atoms.arrays:
        raise ValueError('Unable to find forces')
    return atoms.get_array('forces')

86 

87 

def prepare_structure(atoms: Atoms,
                      atoms_ideal: Atoms,
                      calc: SinglePointCalculator = None,
                      check_permutation: bool = True) -> Atoms:
    """Prepare a structure in the format suitable for a
    :class:`StructureContainer <hiphive.StructureContainer>`.

    Either forces should be attached to the input atoms object as an array,
    or the atoms object should have a SinglePointCalculator attached to it
    containing forces, or a calculator (calc) should be supplied.

    Parameters
    ----------
    atoms
        input structure
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces
    check_permutation
        whether :func:`find_permutation` should be used to reorder the atoms
        of ``atoms`` to match ``atoms_ideal``; if False the original atom
        order is kept

    Returns
    -------
    ASE atoms object
        prepared ASE atoms object with forces and displacements as arrays
        and positions reset to those of ``atoms_ideal``
    """
    forces = _get_forces_from_atoms(atoms, calc=calc)

    # Index array mapping atoms onto the ordering of atoms_ideal. np.arange
    # always yields an integer array; np.array of an empty list would be a
    # float array, which numpy refuses to use as an index (forces[perm]).
    if check_permutation:
        perm = find_permutation(atoms, atoms_ideal)
    else:
        perm = np.arange(len(atoms))

    atoms_new = atoms.copy()[perm]
    atoms_new.arrays['forces'] = forces[perm]
    atoms_new.arrays['displacements'] = get_displacements(atoms_new, atoms_ideal)
    # positions are reset to the ideal ones; the deviation from them is
    # carried by the 'displacements' array set above
    atoms_new.positions = atoms_ideal.positions

    return atoms_new

131 

132 

def prepare_structures(structures: List[Atoms],
                       atoms_ideal: Atoms,
                       calc: SinglePointCalculator = None,
                       check_permutation: bool = True) -> List[Atoms]:
    """Prepares a set of structures in the format suitable for adding them to
    a :class:`StructureContainer <hiphive.StructureContainer>`.

    `structures` should represent a list of supercells with displacements
    while `atoms_ideal` should provide the ideal reference structure (without
    displacements) for the given structures.

    The structures that are returned will have their positions reset to the
    ideal structure. Displacements and forces will be added as arrays to the
    atoms objects.

    If no calculator is provided, then there must be an ASE
    `SinglePointCalculator <ase.calculators.singlepoint>` object attached to
    the structures or the forces should already be attached as
    arrays to the structures.

    If a calculator is provided then it will be used to compute the forces for
    all structures; in that case the structures must not have a calculator
    attached themselves, otherwise a ValueError is raised.

    Example
    -------

    The following example illustrates the use of this function::

        db = connect('dft_training_structures.db')
        training_structures = [row.toatoms() for row in db.select()]
        training_structures = prepare_structures(training_structures, atoms_ideal)
        for s in training_structures:
            sc.add_structure(s)

    Parameters
    ----------
    structures
        list of input displaced structures
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces
    check_permutation
        whether each structure should be reordered to match `atoms_ideal`
        (see :func:`prepare_structure`)

    Returns
    -------
    list of prepared structures with forces and displacements as arrays
    """
    return [prepare_structure(displaced, atoms_ideal, calc=calc,
                              check_permutation=check_permutation)
            for displaced in structures]

181 

182 

def find_permutation(atoms: Atoms, atoms_ref: Atoms) -> List[int]:
    """ Returns the best permutation of atoms for mapping one
    configuration onto another.

    For every site in ``atoms_ref``, the atom in ``atoms`` closest to it is
    selected, with periodic boundary conditions taken into account.

    Parameters
    ----------
    atoms
        configuration to be permuted
    atoms_ref
        configuration onto which to map

    Returns
    -------
    list of indices into ``atoms``, one per site of ``atoms_ref``

    Raises
    ------
    AssertionError
        if the cells of the two configurations differ (this check is skipped
        when Python runs with ``-O``)
    ValueError
        if two reference sites map onto the same atom or if matched sites
        are occupied by different species

    Examples
    --------
    After obtaining the permutation via ``p = find_permutation(atoms1, atoms2)``
    the reordered structure ``atoms1[p]`` will give the closest match
    to ``atoms2``.
    """
    # NOTE(review): kept as an assert so the raised exception type is
    # unchanged for existing callers; note that it is stripped under -O.
    assert np.linalg.norm(atoms.cell - atoms_ref.cell) < 1e-6

    positions = atoms.positions
    cell = atoms_ref.cell
    permutation = []
    # One reference site at a time: keeps memory at O(N) instead of building
    # the full N x N distance matrix.
    for ref_position in atoms_ref.positions:
        dist_row = get_distances(positions, ref_position, cell=cell, pbc=True)[1][:, 0]
        permutation.append(int(np.argmin(dist_row)))

    if len(set(permutation)) != len(permutation):
        raise ValueError('Duplicates in permutation')
    # chemical symbols are determined by atomic numbers, so comparing the
    # numbers is equivalent to comparing symbols site by site
    if not np.array_equal(atoms.numbers[permutation], atoms_ref.numbers):
        raise ValueError('Matching lattice sites have different occupation')
    return permutation

213 

214 

class Shell:
    """
    Neighbor shell: a pair of atomic types, the interatomic distance at
    which the shell sits, and the number of pairs belonging to it.

    Parameters
    ----------
    types : list or tuple
        atomic types for neighbor shell (two entries are expected by the
        string representation)
    distance : float
        interatomic distance for neighbor shell
    count : int
        number of pairs in the neighbor shell
    """

    def __init__(self,
                 types: List[str],
                 distance: float,
                 count: int = 0):
        self.types = types
        self.distance = distance
        self.count = count

    def __str__(self):
        template = '{}-{} distance: {:10.6f} count: {}'
        return template.format(*self.types, self.distance, self.count)

    # repr is intentionally the same human-readable summary
    __repr__ = __str__

242 

243 

def get_neighbor_shells(atoms: Atoms,
                        cutoff: float,
                        dist_tol: float = 1e-5) -> List[Shell]:
    """ Returns a list of neighbor shells.

    Distances are grouped into shells via the following algorithm:

    1. Sort all pair distances within `cutoff` in ascending order

    2. For each pair, compare its distance with the distance of the most
       recently created shell having the same pair of atomic types; if they
       differ by less than `dist_tol` the pair is added to that shell,
       otherwise a new shell is started at this distance

    A warning is logged if two shells with the same atomic types are
    separated by no more than `2 * dist_tol`.

    Parameters
    ----------
    atoms
        configuration used for finding shells
    cutoff
        exclude neighbor shells which have a distance larger than this value
    dist_tol
        distance tolerance

    Returns
    -------
    shells sorted by distance, then atomic types, then count; the count of
    a shell is the number of (i, j) entries returned by
    :func:`ase.neighborlist.neighbor_list` that fall into it
    """

    # get distances, sorted ascending
    ijd = sorted(zip(*neighbor_list('ijd', atoms, cutoff)), key=lambda x: x[2])

    # sort into shells
    symbols = atoms.get_chemical_symbols()
    shells = []
    # Distances arrive in ascending order, so a new distance can only lie
    # within dist_tol of the latest shell with the same types: every earlier
    # shell of those types is at least dist_tol below the latest one's start.
    # Tracking only the latest shell per type pair avoids rescanning all
    # shells for every pair.
    latest_shell = {}
    for i, j, d in ijd:
        types = tuple(sorted([symbols[i], symbols[j]]))
        shell = latest_shell.get(types)
        if shell is not None and abs(d - shell.distance) < dist_tol:
            shell.count += 1
        else:
            shell = Shell(types, d, 1)
            shells.append(shell)
            latest_shell[types] = shell
    shells.sort(key=lambda x: (x.distance, x.types, x.count))

    # warning if two shells of the same types are within 2 * tol
    for idx, s1 in enumerate(shells):
        for s2 in shells[idx + 1:]:
            if s1.types != s2.types:
                continue
            if not s1.distance < s2.distance - 2 * dist_tol:
                logger.warning('Found two shells within 2 * dist_tol')

    return shells

297 

298 

def extract_parameters(fcs: ForceConstants,
                       cs: ClusterSpace,
                       sanity_check: bool = True,
                       lstsq_method: str = 'numpy') -> np.ndarray:
    """ Extracts parameters from force constants.

    This function can be used to extract parameters to create a
    ForceConstantPotential from a known set of force constants.
    The parameters are the least-squares solution computed by NumPy's `lstsq
    function
    <https://docs.scipy.org/doc/numpy/reference/generated/numpy.linalg.lstsq.html>`_
    or by SciPy's `sparse lsqr function
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html>`_.
    Using `lstsq_method='scipy'` might be faster and have a smaller memory footprint for large
    systems, at the expense of some accuracy. This is due to the use of sparse matrices
    and an iterative solver.

    Parameters
    ----------
    fcs
        force constants
    cs
        cluster space
    sanity_check
        whether or not to perform a sanity check; if True the relative error
        between the input fcs and the fcs reconstructed from the fitted
        parameters is printed (in percent) for each order of the cluster space
    lstsq_method
        method to use when making a least squares fit of a ForceConstantModel to the given fcs,
        allowed values are 'numpy' for `np.linalg.lstsq` or 'scipy' for `scipy.sparse.linalg.lsqr`

    Returns
    -------
    parameters
        parameters that together with the ClusterSpace generate the best representation of
        the FCs

    Raises
    ------
    ValueError
        if `lstsq_method` is neither 'numpy' nor 'scipy'
    """
    from .force_constant_model import ForceConstantModel
    from .force_constant_potential import ForceConstantPotential

    if lstsq_method not in ['numpy', 'scipy']:
        raise ValueError('lstsq_method must be either numpy or scipy')

    # extract the parameters
    fcm = ForceConstantModel(fcs.supercell, cs)
    if lstsq_method == 'numpy':
        A, b = fcm.get_fcs_sensing(fcs, sparse=False)
        parameters = np.linalg.lstsq(A, b, rcond=None)[0]
    else:
        # If the cluster space is large, a sparse least squares solver is faster;
        # scipy is only imported when this path is actually taken
        from scipy.sparse.linalg import lsqr
        A, b = fcm.get_fcs_sensing(fcs, sparse=True)
        # set minimal tolerances to maximize iterative least squares accuracy
        parameters = lsqr(A, b, atol=0, btol=0, conlim=0)[0]

    # calculate the relative force constant error
    if sanity_check:
        fcp = ForceConstantPotential(cs, parameters)
        fcs_hiphive = fcp.get_force_constants(fcs.supercell)
        for order in cs.cutoffs.orders:
            fc_original = fcs.get_fc_array(order=order)
            fc_reconstructed = fcs_hiphive.get_fc_array(order=order)
            rel_error = np.linalg.norm(fc_original-fc_reconstructed) / np.linalg.norm(fc_original)
            print(f'Force constant reconstruction error order {order}: {100*rel_error:9.4f}%')

    return parameters