Coverage for hiphive/core/cluster_space_builder.py: 97%
247 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1import itertools
2import numpy as np
3import spglib as spg
4from collections import Counter
5from ase.neighborlist import NeighborList
7from .atoms import Atom, Atoms
8from .clusters import get_clusters
9from .orbits import get_orbits
10from ..input_output.logging_tools import logger
11from .eigentensors import create_eigentensors as _create_eigentensors
12from .tensors import rotate_tensor_precalc, rotation_tensor_as_matrix
13from .translational_constraints import create_constraint_map as _create_constraint_map
14from .utilities import BiMap, ase_atoms_to_spglib_tuple
17# TODO: Add longer description of each step
18# TODO: Be careful with side effects
19# TODO: Preferably the functions should not take the cs as input
def build_cluster_space(cluster_space, prototype_structure):
    """Populate ``cluster_space`` by running every construction step in order.

    The steps are, in sequence: the permutation lookup table, the primitive
    cell, the symmetry dataset, the list of neighbor atoms, the cluster list,
    the orbits, the eigentensors, the removal of orbits without eigentensors,
    the rotation of eigentensors into the orientation families, the
    constraint map, the rescaling of eigentensors (and constraints), the
    normalization of the constraint vectors and finally the transformation of
    the eigentensors with the primitive cell.

    Every helper stores its result as an attribute on ``cluster_space``;
    nothing is returned.

    Parameters
    ----------
    cluster_space
        The cluster space being built; it is mutated in place.
    prototype_structure
        Structure from which the primitive cell is derived (converted via
        ``ase_atoms_to_spglib_tuple``; presumably an ASE ``Atoms`` object).
    """
    # The permutation list is an indexed fast lookup table for permutation
    # vectors.
    logger.debug('Populate permutation list')
    _create_permutations(cluster_space)

    # The primitive cell is calculated by spglib and contains the cell metric,
    # basis and atomic numbers. After this step the prototype structure is no
    # longer needed.
    logger.debug('Get primitive cell')
    _create_primitive_cell(cluster_space, prototype_structure)

    # The symmetries are calculated by spglib. The main information is the
    # rotation matrices and translation vectors of each symmetry operation in
    # scaled coordinates.
    logger.debug('Get symmetries')
    _create_symmetry_dataset(cluster_space)

    # The neighbor atoms of the center cell are stored in an indexed list that
    # contains all atoms within the maximum cutoff.
    logger.debug('Find neighbors')
    _create_atom_list(cluster_space)

    # Clusters are generated as combinations of the indices of the atoms in
    # the atom list. Whether a cluster is included depends on the
    # specification of valid clusters in the cutoffs object. Often all atoms
    # must lie within some distance of each other, which may depend on both
    # the expansion order and the number of atoms in the cluster.
    logger.debug('Starting generating clusters')
    _create_cluster_list(cluster_space)
    logger.debug('Finished generating clusters')

    # The clusters are categorized into orbits and the eigensymmetries are
    # stored in each orbit.
    logger.debug('Starting categorizing clusters')
    _create_orbits(cluster_space)
    logger.debug('Finished categorizing clusters')

    # The eigensymmetries from the previous step are used to generate valid
    # eigentensors.
    logger.debug('Starting finding eigentensors')
    _create_eigentensors(cluster_space)
    logger.debug('Finished finding eigentensors')

    # Orbits that cannot have a force constant due to symmetry are removed
    # from the _orbits attribute but kept in the separate _dropped_orbits list.
    logger.debug('Dropping orbits...')
    _drop_orbits(cluster_space)

    # Each orientation family is populated with its rotated version of the
    # orbit's eigentensors.
    logger.debug('Rotating eigentensors into ofs...')
    _populate_ofs_with_ets(cluster_space)

    # The matrix describing the mapping that preserves the global symmetries
    # (translational and rotational symmetry) is created.
    logger.debug('Constructing constraint map')
    _create_constraint_map(cluster_space)
    logger.info('Constraints:')
    logger.info(' Acoustic: {}'.format(cluster_space.acoustic_sum_rules))
    ndofs_by_order = {o: cluster_space.get_n_dofs_by_order(o)
                      for o in cluster_space.cutoffs.orders}
    logger.info(' Number of degrees of freedom: {}'.format(ndofs_by_order))
    for order, count in ndofs_by_order.items():
        if count == 0:
            logger.warning(' Warning: No degrees of freedom exists for order {}'.format(order))
    logger.info(' Total number of degrees of freedom: {}'.format(cluster_space.n_dofs))

    # The eigentensors are rescaled depending on order to create a better
    # behaved system for fitting. NOTE: the constraints are rescaled too.
    logger.debug('Rescale eigentensors')
    _rescale_eigentensors(cluster_space)

    logger.debug('Normalize constraints')
    _normalize_constraint_vectors(cluster_space)

    logger.debug('Rotate tensors to Carteesian coordinates')
    _rotate_eigentensors(cluster_space)
112# TODO: Actual input could be just the maximum order
113# TODO: No side effects, returns only the permutation BiMap
def _create_permutations(cs):
    """Build the permutation lookup table and store it as ``cs._permutations``.

    For every order ``n`` in ``cs.cutoffs.orders`` (in that sequence), all
    permutations of ``range(n)`` are appended to a ``BiMap`` in the order
    produced by ``itertools.permutations``.
    """
    every_permutation = itertools.chain.from_iterable(
        itertools.permutations(range(n)) for n in cs.cutoffs.orders)
    table = BiMap()
    for perm in every_permutation:
        table.append(perm)
    cs._permutations = table
# TODO: tolerances must be fixed in a coherent way. Preferably via a config
# object
# TODO: Does the basis check need to be done? There might not be a problem that
# it is close to 1 instead of 0 anymore. If that's the case it is better to keep
# it as spglib returns it
# TODO: Add good debug
# TODO: Assert spos dot cell == pos. Sometimes the positions can be outside of
# the cell and then there is a mismatch between what is returned by
# get_scaled_positions, positions and what spos dot cell gives (it should give
# the position)
# TODO: Check basis to see if it can be represented by sympy. (in preparation
# for rotational sum rules)
# TODO: general function -> break out into utility function
# TODO: Send the tolerance as a parameter instead of the whole cs.
def _create_primitive_cell(cs, prototype_structure):
    """Determine the primitive cell of ``prototype_structure`` and store it as ``cs._prim``.

    spglib's ``standardize_cell`` (``to_primitive=True``, ``no_idealize=True``,
    tolerance ``cs.symprec``) proposes a primitive cell. If that cell has the
    same sorted atomic numbers as the prototype and, within an absolute
    tolerance of ``cs.symprec``, the same cell volume, the prototype itself is
    used (copied and wrapped). Otherwise the spglib cell is used. In that case,
    if any scaled position exceeds ``1 - cs.symprec``, all positions are
    rounded to 8 decimals and taken modulo 1.

    A summary of the chosen cell is logged. When the cell has five or more
    atoms, only the first three basis atoms are listed.
    """
    spg_cell = spg.standardize_cell(ase_atoms_to_spglib_tuple(prototype_structure),
                                    no_idealize=True, to_primitive=True,
                                    symprec=cs.symprec)
    spg_lattice, spg_basis, spg_numbers = spg_cell[0], spg_cell[1], spg_cell[2]

    same_numbers = sorted(spg_numbers) == sorted(prototype_structure.numbers)
    spg_volume = np.abs(np.linalg.det(spg_lattice.T))
    prototype_volume = np.abs(np.linalg.det(prototype_structure.cell.T))
    # TODO: is symprec the best tolerance to use for volume check?
    same_volume = np.isclose(spg_volume, prototype_volume, atol=cs.symprec, rtol=0)

    if not (same_numbers and same_volume):
        if np.any(spg_basis > (1 - cs.symprec)):
            logger.debug('Found basis close to 1:\n {}'.format(str(spg_basis)))
            spg_basis = spg_basis.round(8) % 1  # TODO
            logger.debug('Wrapping to:\n {}'.format(str(spg_basis)))
        prim = Atoms(cell=spg_lattice, scaled_positions=spg_basis,
                     numbers=spg_numbers, pbc=True)
    else:
        # The prototype is already primitive: reuse it, wrapped into its cell.
        prim = Atoms(prototype_structure)
        prim.wrap()

    # Log a summary of the primitive cell.
    logger.info('Primitive cell:')
    logger.info(' Formula: {}'.format(prim.get_chemical_formula()))
    cell_entries = [*prim.cell[0], *prim.cell[1], *prim.cell[2]]
    logger.info((' Cell:' + '\n [{:9.5f} {:9.5f} {:9.5f}]' * 3).format(*cell_entries))
    logger.info(' Basis:')
    is_large = len(prim) >= 5
    listed = prim[:3] if is_large else prim
    for symbol, spos in zip(listed.get_chemical_symbols(), listed.basis):
        logger.info(' {:2} [{:9.5f} {:9.5f} {:9.5f}]'.format(symbol, *spos))
    if is_large:
        logger.info(' ...')
    logger.info('')
    cs._prim = prim
# TODO: Fix how the tolerance is handled
# TODO: Look over properties to access symmetry_dataset. Especially rotation,
# translation and wyckoff
# TODO: Send prim and symprec as parameters
def _create_symmetry_dataset(cs):
    """Compute the spglib symmetry dataset of the primitive structure of ``cs``.

    The dataset (computed with tolerance ``cs.symprec``) is stored as
    ``cs._symmetry_dataset``. Then the space group, the number of distinct
    entries in ``cs.wyckoff_sites``, the number of rotation matrices and
    ``cs.symprec`` are logged.
    """
    spglib_cell = ase_atoms_to_spglib_tuple(cs.primitive_structure)
    cs._symmetry_dataset = spg.get_symmetry_dataset(spglib_cell, symprec=cs.symprec)

    logger.info('Crystal symmetry:')
    logger.info(' Spacegroup: {}'.format(cs.spacegroup))
    n_unique_sites = len(set(cs.wyckoff_sites))
    logger.info(' Unique site: {}'.format(n_unique_sites))
    n_operations = len(cs.rotation_matrices)
    logger.info(' Symmetry operations: {}'.format(n_operations))
    logger.info(' symprec: {:.2e}'.format(cs.symprec))
    logger.info('')
195# TODO: Fix how the tolerance is handled
196# TODO: Refactor the two runs of finding the neighbors
197# TODO: The bug that the cutoff is exactly on a shell might be a non issue.
198# TODO: It is possible to check that the clusters map out the orbits completely
199# TODO: Send in prim, cutoff and config and return atom_list instead
def _create_atom_list(cs):
    """Collect the center-cell atoms and their images within the maximum cutoff.

    The resulting ``BiMap`` (stored as ``cs._atom_list``) first holds every
    primitive-cell atom with a zero offset. After those come all further
    ``Atom(site, offset)`` images returned by an ASE ``NeighborList`` with a
    pair cutoff of ``max_cutoff - symprec``. A second search with a pair
    cutoff of ``max_cutoff + symprec`` verifies that the maximum cutoff does
    not lie (almost) on a neighbor shell.

    Raises
    ------
    Exception
        If the enlarged search finds an atom missing from the list whose
        distance to the nearest center atom is below ``max_cutoff + symprec``.
    """
    tol = cs.symprec
    prim = cs._prim
    n_centers = len(prim)

    atom_list = BiMap()
    # The center atoms come first, each with a zero lattice offset.
    for site in range(n_centers):
        atom_list.append(Atom(site, [0, 0, 0]))

    logger.info('Cutoffs:')
    logger.info(' Maximum cutoff: {}'.format(cs.cutoffs.max_cutoff))

    def neighbor_images(pair_cutoff):
        # ASE counts two atoms as neighbors when their cutoff spheres overlap,
        # so each atom receives half of the requested pair cutoff.
        neighbor_list = NeighborList(cutoffs=[pair_cutoff / 2] * n_centers,
                                     skin=0, self_interaction=True, bothways=True)
        neighbor_list.update(prim)
        for center in range(n_centers):
            for index, offset in zip(*neighbor_list.get_neighbors(center)):
                yield Atom(index, offset)

    # Add every image within the (slightly reduced) maximum cutoff, once.
    for image in neighbor_images(cs.cutoffs.max_cutoff - tol):
        if image not in atom_list:
            atom_list.append(image)

    # Search again with a slightly enlarged cutoff. Any image not already
    # collected lies next to the cutoff sphere; remember its smallest gap.
    smallest_gap = tol
    for image in neighbor_images(cs.cutoffs.max_cutoff + tol):
        if image in atom_list:
            continue
        pos = image.pos(prim.basis, prim.cell)
        nearest = min(np.linalg.norm(pos - center_atom.position) for center_atom in prim)
        smallest_gap = min(nearest - cs.cutoffs.max_cutoff, smallest_gap)

    if smallest_gap != tol:
        raise Exception('Maximum cutoff close to neighbor shell, change cutoff')

    n_images = len(atom_list) - n_centers
    plural = 's' if n_centers > 1 else ''
    logger.info(' Found {} center atom{} with {} images totaling {} atoms'.format(
        n_centers, plural, n_images, len(atom_list)))
    logger.info('')

    cs._atom_list = atom_list
250# TODO: add atoms property to cs
251# TODO: Only inputs are prim, atom_list and cutoffs
def _create_cluster_list(cs):
    """Generate the clusters of ``cs`` and store them in ``cs._cluster_list``.

    The atom list is turned into a non-periodic ``Atoms`` object. That object
    uses the primitive cell metric, the scaled positions of the atom-list
    entries and the atomic numbers of their primitive-cell sites. The cluster
    filter is set up on this object. Then ``get_clusters`` is called with it,
    ``cs.cutoffs`` and the number of primitive-cell atoms. A count of clusters
    per cluster size is logged.
    """
    prim = cs._prim
    scaled_positions = []
    atomic_numbers = []
    for entry in cs._atom_list:
        scaled_positions.append(entry.spos(prim.basis))
        atomic_numbers.append(prim.numbers[entry.site])

    cluster_atoms = Atoms(cell=prim.cell, scaled_positions=scaled_positions,
                          numbers=atomic_numbers, pbc=False)

    cs._cluster_filter.setup(cluster_atoms)
    cs._cluster_list = get_clusters(cluster_atoms, cs.cutoffs, len(prim))

    logger.info('Clusters:')
    size_counts = Counter(map(len, cs._cluster_list))
    logger.info(' Clusters: {}'.format(dict(size_counts)))
    logger.info(' Total number of clusters: {}\n'.format(sum(size_counts.values())))
def _create_orbits(cs):
    """Group the clusters of ``cs`` into orbits and drop filtered-out orbits.

    ``get_orbits`` is called with the cluster list, atom list, rotation
    matrices, translation vectors, permutations, primitive cell and
    ``cs.symprec``. An orbit is kept only if ``cs._cluster_filter`` accepts
    its prototype cluster. Rejected orbits are stored in a freshly created
    ``cs._dropped_orbits`` list; accepted ones remain in ``cs._orbits``. Both
    lists keep the order returned by ``get_orbits``. Finally the number of
    orbits per order is logged.
    """
    # TODO: Check scaled/cart
    cs._orbits = get_orbits(cs._cluster_list,
                            cs._atom_list,
                            cs.rotation_matrices,
                            cs.translation_vectors,
                            cs.permutations,
                            cs._prim,
                            cs.symprec)

    # Partition in a single pass. Testing list membership per orbit would be
    # quadratic in the number of orbits.
    kept_orbits = []
    dropped_orbits = []
    for orbit in cs.orbits:
        prototype_cluster = cs._cluster_list[orbit.prototype_index]
        if cs._cluster_filter(prototype_cluster):
            kept_orbits.append(orbit)
        else:
            dropped_orbits.append(orbit)
    cs._dropped_orbits = dropped_orbits
    cs._orbits = kept_orbits

    counter = Counter(orb.order for orb in cs._orbits)
    logger.info('Orbits:')
    logger.info(' Orbits: {}'.format(dict(counter)))
    logger.info(' Total number of orbits: {}\n'.format(sum(counter.values())))
def _drop_orbits(cs):
    """Discard orbits without eigentensors and log an eigentensor summary.

    Orbits whose ``eigentensors`` is empty are appended to the existing
    ``cs._dropped_orbits`` list. All other orbits stay in ``cs._orbits`` in
    their original order.

    Afterwards the following are logged:
    - the number of eigentensors per order and their total;
    - the prototype cluster of every dropped orbit (including orbits dropped
      before this call);
    - a warning for each order that no longer has any orbit.
    """
    # Single-pass partition. A list-membership test per orbit would be
    # quadratic in the number of orbits.
    kept_orbits = []
    for orbit in cs.orbits:
        if orbit.eigentensors:
            kept_orbits.append(orbit)
        else:
            cs._dropped_orbits.append(orbit)
    cs._orbits = kept_orbits

    logger.info('Eigentensors:')
    n_ets = {order: sum(len(orb.eigentensors) for orb in cs.orbits if orb.order == order)
             for order in cs.cutoffs.orders}
    logger.info(' Eigentensors: {}'.format(n_ets))
    logger.info(' Total number of parameters: {}'.format(sum(n_ets.values())))
    if cs._dropped_orbits:
        logger.info(' Discarded orbits:')
        for orb in cs._dropped_orbits:
            logger.info(' {}'.format(cs.cluster_list[orb.prototype_index]))
    counter = Counter(orb.order for orb in cs._orbits)
    for order in cs.cutoffs.orders:
        if counter[order] == 0:
            logger.warning(' Warning: No orbits exists for order {}'.format(order))
    logger.info('')
def _populate_ofs_with_ets(cs):
    """Give every orientation family the eigentensors of its orbit, rotated.

    For each orientation family, the inverse of the rotation matrix at
    ``of.symmetry_index`` is required (and asserted) to be integer-valued. It
    is converted to tensor form for the orbit's order with
    ``rotation_tensor_as_matrix``, and each orbit eigentensor is rotated with
    ``rotate_tensor_precalc``. The rotated tensors are asserted to have dtype
    int64 and replace ``of.eigentensors``. Tensor-form inverses are cached per
    ``(symmetry_index, order)``.
    """
    R_inv_lookup = {}
    for orbit in cs.orbits:
        for of in orbit.orientation_families:
            lookup_key = (of.symmetry_index, orbit.order)
            R_inv = R_inv_lookup.get(lookup_key)
            if R_inv is None:
                R = cs.rotation_matrices[of.symmetry_index]
                R_inv_float = np.linalg.inv(R)
                # np.linalg.inv works in floating point, so an entry that is
                # mathematically 1 may come out as 0.9999999999999998.
                # astype() truncates toward zero and would turn that into 0,
                # so round to the nearest integer before casting.
                R_inv_int = np.rint(R_inv_float).astype(np.int64)
                assert np.allclose(R_inv_int, R_inv_float), (R_inv_int, R_inv_float)
                R_inv = rotation_tensor_as_matrix(R_inv_int, orbit.order)
                R_inv_lookup[lookup_key] = R_inv
            of.eigentensors = []
            for et in orbit.eigentensors:
                rotated_et = rotate_tensor_precalc(et, R_inv)
                assert rotated_et.dtype == np.int64
                of.eigentensors.append(rotated_et)
def _rotate_eigentensors(cs):
    """Transform all eigentensors with the inverse transpose of the primitive cell.

    Every eigentensor of every orbit, and of each of its orientation
    families, is passed through ``rotate_tensor_precalc``. The transform used
    is the order-specific tensor form of ``inv(cell.T)``, built with
    ``rotation_tensor_as_matrix`` once per order. The eigentensor lists are
    replaced with the transformed tensors.
    """
    inv_cell_T = np.linalg.inv(cs._prim.cell.T)
    transform_by_order = {}
    for orbit in cs.orbits:
        order = orbit.order
        if order not in transform_by_order:
            transform_by_order[order] = rotation_tensor_as_matrix(inv_cell_T, order)
        transform = transform_by_order[order]

        orbit.eigentensors = [rotate_tensor_precalc(tensor, transform)
                              for tensor in orbit.eigentensors]
        for family in orbit.orientation_families:
            family.eigentensors = [rotate_tensor_precalc(tensor, transform)
                                   for tensor in family.eigentensors]
365def _normalize_constraint_vectors(cs):
366 M = cs._cvs
367 norms = np.zeros(M.shape[1])
368 for c, v in zip(M.col, M.data):
369 norms[c] += v**2
370 for i in range(len(norms)):
371 norms[i] = np.sqrt(norms[i])
372 for i, c in enumerate(M.col):
373 M.data[i] /= norms[c]
376def _rescale_eigentensors(cs):
377 for orbit in cs.orbits:
378 norm = cs.length_scale**orbit.order
379 ets = orbit.eigentensors
380 for i, et in enumerate(ets):
381 ets[i] = et.astype(np.float64)
382 ets[i] /= norm
383 for of in orbit.orientation_families:
384 ets = of.eigentensors
385 for i, et in enumerate(ets):
386 ets[i] = et.astype(np.float64)
387 ets[i] /= norm
389 orders = []
390 for orbit in cs.orbits:
391 for et in orbit.eigentensors:
392 orders.append(orbit.order)
394 M = cs._cvs
395 for i, r in enumerate(M.row):
396 M.data[i] *= cs.length_scale**orders[r]