Coverage for hiphive/core/cluster_space_builder.py: 97%

247 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

1import itertools 

2import numpy as np 

3import spglib as spg 

4from collections import Counter 

5from ase.neighborlist import NeighborList 

6 

7from .atoms import Atom, Atoms 

8from .clusters import get_clusters 

9from .orbits import get_orbits 

10from ..input_output.logging_tools import logger 

11from .eigentensors import create_eigentensors as _create_eigentensors 

12from .tensors import rotate_tensor_precalc, rotation_tensor_as_matrix 

13from .translational_constraints import create_constraint_map as _create_constraint_map 

14from .utilities import BiMap, ase_atoms_to_spglib_tuple 

15 

16 

# TODO: Add longer description of each step
# TODO: Be careful with side effects
# TODO: Preferably the functions should not take the cs as input
def build_cluster_space(cluster_space, prototype_structure):
    """Populate ``cluster_space`` in place, starting from ``prototype_structure``.

    The construction steps below run in a fixed order. Each step stores its
    result as attributes on ``cluster_space``, and later steps read what
    earlier steps stored (e.g. ``_prim`` is written by the primitive-cell step
    and read by the neighbor search), so the order must be preserved.

    Parameters
    ----------
    cluster_space
        Cluster space object that is filled in; attributes such as
        ``cutoffs`` and ``symprec`` are read by the individual steps.
    prototype_structure
        Structure from which the primitive cell is derived (presumably an ASE
        ``Atoms`` object, since it is converted with
        ``ase_atoms_to_spglib_tuple``). It is not used after that step.
    """
    # The permutation list is an indexed fast lookup table for permutation
    # vectors.
    logger.debug('Populate permutation list')
    _create_permutations(cluster_space)

    # The primitive cell is calculated by spglib and contains the cell
    # metric, basis and atomic numbers. After this the prototype structure is
    # not used any more.
    logger.debug('Get primitive cell')
    _create_primitive_cell(cluster_space, prototype_structure)

    # The symmetries are calculated by spglib; the main information is the
    # rotation matrices and translation vectors of each symmetry operation in
    # scaled coordinates.
    logger.debug('Get symmetries')
    _create_symmetry_dataset(cluster_space)

    # The neighbor atoms of the center cell are stored in an indexed list
    # containing all atoms within the specified maximum cutoff.
    logger.debug('Find neighbors')
    _create_atom_list(cluster_space)

    # Clusters are generated as combinations of the indices of the atoms in
    # the atom list. Whether a cluster is included depends on the
    # specification of valid clusters in the cutoffs object. Often all atoms
    # must be within some distance of each other, which may depend on both the
    # expansion order and the number of atoms in the cluster.
    logger.debug('Starting generating clusters')
    _create_cluster_list(cluster_space)
    logger.debug('Finished generating clusters')

    # The clusters are categorized into orbits and the eigensymmetries are
    # stored in each orbit.
    logger.debug('Starting categorizing clusters')
    _create_orbits(cluster_space)
    logger.debug('Finished categorizing clusters')

    # The eigensymmetries from the previous step are used to generate valid
    # eigentensors.
    logger.debug('Starting finding eigentensors')
    _create_eigentensors(cluster_space)
    logger.debug('Finished finding eigentensors')

    # Orbits that cannot carry a force constant due to symmetry are removed
    # from the _orbits attribute but kept in a separate list, _dropped_orbits.
    logger.debug('Dropping orbits...')
    _drop_orbits(cluster_space)

    # Each orientation family is populated with its rotated version of the
    # orbit's eigentensors.
    logger.debug('Rotating eigentensors into ofs...')
    _populate_ofs_with_ets(cluster_space)

    # The matrix describing the mapping that preserves the global symmetries
    # (translational and rotational) is created.
    logger.debug('Constructing constraint map')
    _create_constraint_map(cluster_space)
    logger.info('Constraints:')
    logger.info(' Acoustic: {}'.format(cluster_space.acoustic_sum_rules))
    ndofs_by_order = {o: cluster_space.get_n_dofs_by_order(o) for o in cluster_space.cutoffs.orders}
    logger.info(' Number of degrees of freedom: {}'.format(ndofs_by_order))
    for order, count in ndofs_by_order.items():
        if count == 0:
            logger.warning(' Warning: No degrees of freedom exists for order {}'.format(order))
    logger.info(' Total number of degrees of freedom: {}'.format(cluster_space.n_dofs))

    # The eigentensors are rescaled depending on order to obtain a better
    # behaved system for fitting. NOTE: the constraints are rescaled here too.
    logger.debug('Rescale eigentensors')
    _rescale_eigentensors(cluster_space)

    logger.debug('Normalize constraints')
    _normalize_constraint_vectors(cluster_space)

    logger.debug('Rotate tensors to Carteesian coordinates')
    _rotate_eigentensors(cluster_space)

110 

111 

# TODO: Actual input could be just the maximum order
# TODO: No side effects, returns only the permutation BiMap
def _create_permutations(cs):
    """Store a BiMap of index permutations as ``cs._permutations``.

    For each order listed in ``cs.cutoffs.orders`` every permutation of
    ``range(order)`` is appended, in the order the orders are listed and in
    ``itertools.permutations`` order within each one.
    """
    wanted_orders = cs.cutoffs.orders
    lookup = BiMap()
    every_permutation = itertools.chain.from_iterable(
        itertools.permutations(range(n)) for n in wanted_orders)
    for perm in every_permutation:
        lookup.append(perm)
    cs._permutations = lookup

121 

122 

# TODO: tolerances must be fixed in a coherent way. Preferably via a config
# object
# TODO: Does the basis check need to be done? There might not be a problem that
# it is close to 1 instead of 0 anymore. If that's the case it is better to keep
# it as spglib returns it
# TODO: Add good debug
# TODO: Assert spos dot cell == pos. Sometimes the positions can be outside of
# the cell and then there is a mismatch between what is returned by
# get_scaled_positions, positions and what spos dot cell gives (it should give
# the position)
# TODO: Check basis to see if it can be represented by sympy. (in preparation
# for rotational sum rules)
# TODO: general function -> break out into utility function
# TODO: Send the tolerance as a parameter instead of the whole cs.
def _create_primitive_cell(cs, prototype_structure):
    """Determine the primitive cell of ``prototype_structure`` and store it as ``cs._prim``.

    spglib's ``standardize_cell`` is called with ``to_primitive=True``,
    ``no_idealize=True`` and ``symprec=cs.symprec``. If its result has the same
    sorted atomic numbers as the prototype, and a cell volume equal to the
    prototype's within an absolute tolerance of ``cs.symprec``, an ``Atoms``
    object is built from the prototype itself and wrapped. Otherwise the
    spglib cell, scaled positions and numbers are used. In that case, if any
    scaled position exceeds ``1 - cs.symprec``, all of them are rounded to
    8 decimals and reduced modulo 1.

    The formula, the cell and the basis are logged. All sites are listed when
    there are fewer than five; otherwise the first three are listed,
    followed by ``...``.
    """
    spg_result = spg.standardize_cell(
        ase_atoms_to_spglib_tuple(prototype_structure),
        no_idealize=True, to_primitive=True, symprec=cs.symprec)
    spg_lattice, spg_numbers = spg_result[0], spg_result[2]

    same_numbers = sorted(spg_numbers) == sorted(prototype_structure.numbers)
    # TODO: is symprec the best tolerance to use for volume check?
    same_volume = np.isclose(
        np.abs(np.linalg.det(spg_lattice.T)),
        np.abs(np.linalg.det(prototype_structure.cell.T)),
        atol=cs.symprec, rtol=0)

    if not (same_numbers and same_volume):
        scaled = spg_result[1]
        if np.any(scaled > (1 - cs.symprec)):
            logger.debug('Found basis close to 1:\n {}'.format(str(scaled)))
            scaled = scaled.round(8) % 1  # TODO
            logger.debug('Wrapping to:\n {}'.format(str(scaled)))
        prim = Atoms(cell=spg_lattice, scaled_positions=scaled,
                     numbers=spg_numbers, pbc=True)
    else:
        prim = Atoms(prototype_structure)
        prim.wrap()

    # log primitive cell information
    logger.info('Primitive cell:')
    logger.info(' Formula: {}'.format(prim.get_chemical_formula()))
    cell_template = ' Cell:' + '\n [{:9.5f} {:9.5f} {:9.5f}]' * 3
    cell_entries = [value for row in range(3) for value in prim.cell[row]]
    logger.info(cell_template.format(*cell_entries))
    logger.info(' Basis:')
    abbreviate = len(prim) >= 5
    listed = prim[:3] if abbreviate else prim
    for symbol, spos in zip(listed.get_chemical_symbols(), listed.basis):
        logger.info(' {:2} [{:9.5f} {:9.5f} {:9.5f}]'.format(symbol, *spos))
    if abbreviate:
        logger.info(' ...')
    logger.info('')
    cs._prim = prim

175 

176 

# TODO: Fix how the tolerance is handled
# TODO: Look over properties to access symmetry_dataset. Especially rotation,
# translation and wyckoff
# TODO: Send prim and symprec as parameters
def _create_symmetry_dataset(cs):
    """Run spglib's symmetry analysis on the primitive structure.

    The result of ``spg.get_symmetry_dataset`` for ``cs.primitive_structure``
    (with ``symprec=cs.symprec``) is stored as ``cs._symmetry_dataset``. A
    summary is then logged: space group, number of distinct Wyckoff sites,
    number of rotation matrices and ``symprec``. These are read through ``cs``
    properties, presumably derived from the stored dataset.
    """
    spg_input = ase_atoms_to_spglib_tuple(cs.primitive_structure)
    cs._symmetry_dataset = spg.get_symmetry_dataset(spg_input, symprec=cs.symprec)

    logger.info('Crystal symmetry:')
    logger.info(' Spacegroup: {}'.format(cs.spacegroup))
    n_unique_sites = len(set(cs.wyckoff_sites))
    logger.info(' Unique site: {}'.format(n_unique_sites))
    n_operations = len(cs.rotation_matrices)
    logger.info(' Symmetry operations: {}'.format(n_operations))
    logger.info(' symprec: {:.2e}'.format(cs.symprec))
    logger.info('')

193 

194 

# TODO: Fix how the tolerance is handled
# TODO: Refactor the two runs of finding the neighbors
# TODO: The bug that the cutoff is exactly on a shell might be a non issue.
# TODO: It is possible to check that the clusters map out the orbits completely
# TODO: Send in prim, cutoff and config and return atom_list instead
def _create_atom_list(cs):
    """Collect the center-cell atoms and their neighbors in ``cs._atom_list``.

    The list starts with one ``Atom(site, [0, 0, 0])`` per site of
    ``cs._prim``. It is followed by every distinct neighbor found by an ASE
    ``NeighborList`` with per-atom radius ``(max_cutoff - symprec) / 2``
    (``self_interaction=True``, ``bothways=True``, ``skin=0``). A second search
    with radius ``(max_cutoff + symprec) / 2`` checks for atoms that sit just
    outside the first search.

    Raises
    ------
    Exception
        If the second search finds an atom missing from the list whose
        distance to the nearest center-cell atom is below
        ``max_cutoff + symprec``, i.e. the maximum cutoff is too close to a
        neighbor shell.
    """
    tol = cs.symprec
    prim = cs._prim
    n_sites = len(prim)
    max_cutoff = cs.cutoffs.max_cutoff

    atom_list = BiMap()
    # The center-cell atoms come first, all with a zero cell offset
    for site in range(n_sites):
        atom_list.append(Atom(site, [0, 0, 0]))

    logger.info('Cutoffs:')
    logger.info(' Maximum cutoff: {}'.format(max_cutoff))

    def _neighbors(radius):
        # Yield an Atom(index, offset) for every neighbor of every center site
        nl = NeighborList(cutoffs=[radius] * n_sites, skin=0,
                          self_interaction=True, bothways=True)
        nl.update(prim)
        for site in range(n_sites):
            for index, offset in zip(*nl.get_neighbors(site)):
                yield Atom(index, offset)

    # Find all the atoms which are neighbors to the atoms in the center cell.
    # The pair cutoff should be larger or equal than the others.
    for neighbor in _neighbors((max_cutoff - tol) / 2):
        if neighbor not in atom_list:
            atom_list.append(neighbor)

    # Search slightly further out and check that no new atom shows up
    closest_gap = tol
    for candidate in _neighbors((max_cutoff + tol) / 2):
        if candidate in atom_list:
            continue
        candidate_pos = candidate.pos(prim.basis, prim.cell)
        gap = min(np.linalg.norm(candidate_pos - center.position)
                  for center in prim) - max_cutoff
        closest_gap = min(gap, closest_gap)

    if closest_gap != tol:
        raise Exception('Maximum cutoff close to neighbor shell, change cutoff')

    n_images = len(atom_list) - n_sites
    msg = ' Found {} center atom{} with {} images totaling {} atoms'.format(
        n_sites, 's' if n_sites > 1 else '', n_images, len(atom_list))
    logger.info(msg)
    logger.info('')

    cs._atom_list = atom_list

248 

249 

# TODO: add atoms property to cs
# TODO: Only inputs are prim, atom_list and cutoffs
def _create_cluster_list(cs):
    """Generate the clusters and store them in ``cs._cluster_list``.

    Every entry of ``cs._atom_list`` (site/offset form) is converted into a
    scaled position via ``entry.spos(prim.basis)`` and paired with the atomic
    number of its site. The results form a non-periodic ``Atoms`` object in
    the primitive cell metric. The cluster filter is set up with that object,
    the clusters are produced by ``get_clusters``, and the number of clusters
    per cluster size is logged.
    """
    prim = cs._prim
    scaled_positions = []
    numbers = []
    for entry in cs._atom_list:
        scaled_positions.append(entry.spos(prim.basis))
        numbers.append(prim.numbers[entry.site])

    cluster_atoms = Atoms(cell=prim.cell, scaled_positions=scaled_positions,
                          numbers=numbers, pbc=False)

    cs._cluster_filter.setup(cluster_atoms)
    cs._cluster_list = get_clusters(cluster_atoms, cs.cutoffs, len(prim))

    logger.info('Clusters:')
    size_counts = Counter(len(cluster) for cluster in cs._cluster_list)
    logger.info(' Clusters: {}'.format(dict(size_counts)))
    logger.info(' Total number of clusters: {}\n'.format(sum(size_counts.values())))

268 

269 

def _create_orbits(cs):
    """Group the clusters into orbits and drop the ones rejected by the filter.

    ``get_orbits`` builds the orbits from the cluster list, atom list,
    rotation matrices, translation vectors, permutations, primitive cell and
    ``symprec`` of ``cs``. An orbit is dropped when ``cs._cluster_filter``
    returns a falsy value for its prototype cluster.
    ``cs._dropped_orbits`` is (re)initialized to the dropped orbits, and
    ``cs._orbits`` keeps the rest in their original order. The number of
    orbits per order is logged.
    """
    # TODO: Check scaled/cart
    cs._orbits = get_orbits(cs._cluster_list,
                            cs._atom_list,
                            cs.rotation_matrices,
                            cs.translation_vectors,
                            cs.permutations,
                            cs._prim,
                            cs.symprec)
    all_orbits = cs.orbits

    # Evaluate the filter for every orbit first, then partition in a single
    # pass (the previous `i in list` membership test made this O(n^2)).
    keep_flags = [bool(cs._cluster_filter(cs._cluster_list[orbit.prototype_index]))
                  for orbit in all_orbits]

    reduced_orbits = []
    cs._dropped_orbits = []
    for orbit, keep in zip(all_orbits, keep_flags):
        if keep:
            reduced_orbits.append(orbit)
        else:
            cs._dropped_orbits.append(orbit)
    cs._orbits = reduced_orbits

    counter = Counter(orb.order for orb in cs._orbits)
    logger.info('Orbits:')
    logger.info(' Orbits: {}'.format(dict(counter)))
    logger.info(' Total number of orbits: {}\n'.format(sum(counter.values())))

297 

298 

def _drop_orbits(cs):
    """Move orbits without eigentensors to ``cs._dropped_orbits``.

    Orbits whose ``eigentensors`` is empty (falsy) are appended to the
    already existing ``cs._dropped_orbits`` list. The remaining orbits become
    ``cs._orbits``, in their original order. Afterwards the following are
    logged: the number of eigentensors per order, the prototype cluster of
    every discarded orbit, and a warning for each order without any orbit
    left.
    """
    # Single-pass partition (the previous `i in list` test made this O(n^2))
    reduced_orbits = []
    for orbit in cs.orbits:
        if orbit.eigentensors:
            reduced_orbits.append(orbit)
        else:
            cs._dropped_orbits.append(orbit)
    cs._orbits = reduced_orbits

    logger.info('Eigentensors:')
    n_ets = {order: sum(len(orb.eigentensors) for orb in cs.orbits if orb.order == order)
             for order in cs.cutoffs.orders}
    logger.info(' Eigentensors: {}'.format(n_ets))
    logger.info(' Total number of parameters: {}'.format(sum(n_ets.values())))
    if len(cs._dropped_orbits) > 0:
        logger.info(' Discarded orbits:')
        for orb in cs._dropped_orbits:
            logger.info(' {}'.format(cs.cluster_list[orb.prototype_index]))
    counter = Counter(orb.order for orb in cs._orbits)
    for order in cs.cutoffs.orders:
        if counter[order] == 0:
            logger.warning(' Warning: No orbits exists for order {}'.format(order))
    logger.info('')

329 

330 

def _populate_ofs_with_ets(cs):
    """Give every orientation family rotated copies of its orbit's eigentensors.

    For each orientation family, the inverse of
    ``cs.rotation_matrices[of.symmetry_index]`` is cast to int64 and expanded
    with ``rotation_tensor_as_matrix`` for the orbit order. The result is
    cached per ``(symmetry_index, order)``. It is applied to every eigentensor
    of the orbit with ``rotate_tensor_precalc``, and ``of.eigentensors`` is
    replaced by the new list.

    The inverse rotation is asserted to be integer-valued, and every rotated
    eigentensor is asserted to have dtype int64.
    """
    R_inv_lookup = {}
    for orbit in cs.orbits:
        for of in orbit.orientation_families:
            R_inv_lookup_index = (of.symmetry_index, orbit.order)
            R_inv = R_inv_lookup.get(R_inv_lookup_index, None)
            if R_inv is None:
                R = cs.rotation_matrices[of.symmetry_index]
                R_inv_tmp = np.linalg.inv(R)
                # Round before casting: astype truncates toward zero, so a
                # float result such as 0.9999999999999998 would become 0.
                R_inv = np.rint(R_inv_tmp).astype(np.int64)
                assert np.allclose(R_inv, R_inv_tmp), (R_inv, R_inv_tmp)
                R_inv = rotation_tensor_as_matrix(R_inv, orbit.order)
                R_inv_lookup[R_inv_lookup_index] = R_inv
            of.eigentensors = []
            for et in orbit.eigentensors:
                rotated_et = rotate_tensor_precalc(et, R_inv)
                assert rotated_et.dtype == np.int64
                of.eigentensors.append(rotated_et)

350 

351 

def _rotate_eigentensors(cs):
    """Transform every eigentensor with ``inv(cs._prim.cell.T)``.

    For each orbit order, ``rotation_tensor_as_matrix(inv(cell.T), order)`` is
    built once and cached by order. It is applied with
    ``rotate_tensor_precalc`` to the eigentensors of each orbit and of each
    of its orientation families. Both attributes are replaced by new lists.
    """
    V_invT = np.linalg.inv(cs._prim.cell.T)
    matrix_for_order = {}
    for orbit in cs.orbits:
        if orbit.order not in matrix_for_order:
            matrix_for_order[orbit.order] = rotation_tensor_as_matrix(V_invT, orbit.order)
        transform = matrix_for_order[orbit.order]

        rotated = []
        for tensor in orbit.eigentensors:
            rotated.append(rotate_tensor_precalc(tensor, transform))
        orbit.eigentensors = rotated

        for family in orbit.orientation_families:
            family_rotated = []
            for tensor in family.eigentensors:
                family_rotated.append(rotate_tensor_precalc(tensor, transform))
            family.eigentensors = family_rotated

363 

364 

365def _normalize_constraint_vectors(cs): 

366 M = cs._cvs 

367 norms = np.zeros(M.shape[1]) 

368 for c, v in zip(M.col, M.data): 

369 norms[c] += v**2 

370 for i in range(len(norms)): 

371 norms[i] = np.sqrt(norms[i]) 

372 for i, c in enumerate(M.col): 

373 M.data[i] /= norms[c] 

374 

375 

376def _rescale_eigentensors(cs): 

377 for orbit in cs.orbits: 

378 norm = cs.length_scale**orbit.order 

379 ets = orbit.eigentensors 

380 for i, et in enumerate(ets): 

381 ets[i] = et.astype(np.float64) 

382 ets[i] /= norm 

383 for of in orbit.orientation_families: 

384 ets = of.eigentensors 

385 for i, et in enumerate(ets): 

386 ets[i] = et.astype(np.float64) 

387 ets[i] /= norm 

388 

389 orders = [] 

390 for orbit in cs.orbits: 

391 for et in orbit.eigentensors: 

392 orders.append(orbit.order) 

393 

394 M = cs._cvs 

395 for i, r in enumerate(M.row): 

396 M.data[i] *= cs.length_scale**orders[r]