Coverage for hiphive/utilities.py: 97%
121 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
1"""
2This module contains various support/utility functions.
3"""
5from typing import List, Tuple
6import numpy as np
8from ase import Atoms
9from ase.calculators.singlepoint import SinglePointCalculator
10from ase.geometry import find_mic
11from ase.geometry import get_distances
12from ase.neighborlist import neighbor_list
13from .cluster_space import ClusterSpace
14from .force_constants import ForceConstants
15from .input_output.logging_tools import logger
# Rebind the imported package logger to a child logger for this module.
# NOTE(review): presumably a standard ``logging.Logger``, in which case
# records emitted here carry the suffix ``.utilities`` -- confirm against
# ``input_output.logging_tools``.
logger = logger.getChild('utilities')
def get_displacements(atoms: Atoms,
                      atoms_ideal: Atoms,
                      cell_tol: float = 1e-4) -> np.ndarray:
    """Returns the smallest possible displacements of a displaced
    configuration relative to an ideal (reference) configuration.

    Notes
    -----
    * uses :func:`ase.geometry.find_mic`
    * assumes periodic boundary conditions in all directions

    Parameters
    ----------
    atoms
        Configuration with displaced atoms.
    atoms_ideal
        Ideal configuration relative to which displacements are computed.
    cell_tol
        Cell tolerance; an error is raised if the norm of the difference
        between the two cells exceeds this value.

    Returns
    -------
    Displacement vectors, wrapped according to the minimum image convention.

    Raises
    ------
    ValueError
        If the atomic numbers of the two configurations differ or if the
        cells differ by more than :attr:`cell_tol`.
    """
    same_species = np.array_equal(atoms.numbers, atoms_ideal.numbers)
    if not same_species:
        raise ValueError('Atomic numbers do not match.')

    cell_mismatch = np.linalg.norm(atoms.cell - atoms_ideal.cell)
    if cell_mismatch > cell_tol:
        raise ValueError('Cells do not match.')

    # only the first element of the find_mic result (the vectors) is needed
    delta = atoms.positions - atoms_ideal.positions
    return find_mic(delta, atoms_ideal.cell, pbc=True)[0]
def _get_forces_from_atoms(atoms: Atoms, calc=None) -> np.ndarray:
    """Try to get forces from an atoms object.

    Forces are taken, in this order of preference, from

    1. the calculator passed as :attr:`calc` (evaluated on a copy of
       :attr:`atoms`),
    2. the calculator attached to :attr:`atoms`, which must be a
       :class:`SinglePointCalculator`,
    3. the ``'forces'`` entry of ``atoms.arrays``.

    When a calculator is used and ``atoms.arrays`` also holds forces, the
    two sets of forces must agree.

    Raises
    ------
    ValueError
        If a calculator is both attached and supplied, if the attached
        calculator is not a :class:`SinglePointCalculator`, if calculator and
        array forces disagree, or if no forces can be found at all.
    """
    has_attached = atoms.calc is not None
    has_supplied = calc is not None

    # two calculators would be ambiguous
    if has_attached and has_supplied:
        raise ValueError('Atoms.calc is not None and calculator was provided.')

    if has_supplied:
        # evaluate on a copy so that the input object is left without calculator
        source = atoms.copy()
        source.calc = calc
    elif has_attached:
        if not isinstance(atoms.calc, SinglePointCalculator):
            raise ValueError('atoms.calc is not a SinglePointCalculator.')
        source = atoms
    else:
        # no calculator anywhere, forces have to be stored as an array
        if 'forces' not in atoms.arrays:
            raise ValueError('Unable to find forces.')
        return atoms.get_array('forces')

    forces_calc = source.get_forces()
    if 'forces' in atoms.arrays and not np.allclose(forces_calc, atoms.get_array('forces')):
        raise ValueError('Forces in atoms.arrays are different from the calculator forces.')
    return forces_calc
def prepare_structure(atoms: Atoms,
                      atoms_ideal: Atoms,
                      calc: SinglePointCalculator = None,
                      check_permutation: bool = True) -> Atoms:
    """Prepare a structure in the format suitable for a
    :class:`StructureContainer <hiphive.StructureContainer>`.

    Either forces should be attached to input atoms object as an array,
    or the atoms object should have a SinglePointCalculator attached to it
    containing forces, or a calculator (calc) should be supplied.

    Parameters
    ----------
    atoms
        Input structure.
    atoms_ideal
        Reference structure relative to which displacements are computed.
    calc
        ASE calculator to use for computing forces.
    check_permutation
        Whether :func:`find_permutation` should be used or not. If not, the
        atoms are assumed to already be in the same order as in
        :attr:`atoms_ideal`.

    Returns
    -------
    Prepared ASE :class:`Atoms` object with forces and displacements as
    arrays and with positions set to those of :attr:`atoms_ideal`.
    """
    # get forces
    forces = _get_forces_from_atoms(atoms, calc=calc)

    # determine the atom order that matches the reference structure
    if check_permutation:
        perm = find_permutation(atoms, atoms_ideal)
    else:
        # np.arange always yields an integer index array, whereas np.array
        # over an empty list defaults to float64, which cannot be used for
        # indexing
        perm = np.arange(len(atoms))

    # setup new atoms
    atoms_new = atoms.copy()[perm]
    atoms_new.arrays['forces'] = forces[perm]
    atoms_new.arrays['displacements'] = get_displacements(atoms_new, atoms_ideal)
    atoms_new.positions = atoms_ideal.positions

    return atoms_new
def prepare_structures(structures: List[Atoms],
                       atoms_ideal: Atoms,
                       calc: SinglePointCalculator = None,
                       check_permutation: bool = True) -> List[Atoms]:
    """Prepares a set of structures in the format suitable for adding them to
    a :class:`StructureContainer <hiphive.StructureContainer>`.

    :attr:`structures` should represent a list of supercells with displacements
    while :attr:`atoms_ideal` should provide the ideal reference structure
    (without displacements) for the given structures.

    Each structure is processed by :func:`prepare_structure`: the returned
    structures have their positions reset to the ideal structure, and
    displacements and forces are added as arrays to the atoms objects.

    If no calculator is provided, then there must be an ASE
    :class:`SinglePointCalculator <ase.calculators.singlepoint>` object
    attached to the structures or the forces should already be attached as
    arrays to the structures.

    If a calculator is provided then it will be used to compute the forces for
    all structures.

    Example
    -------

    The following example illustrates the use of this function::

        db = connect('dft_training_structures.db')
        training_structures = [row.toatoms() for row in db.select()]
        training_structures = prepare_structures(training_structures, atoms_ideal)
        for s in training_structures:
            sc.add_structure(s)

    Parameters
    ----------
    structures
        List of input displaced structures.
    atoms_ideal
        Reference structure relative to which displacements are computed.
    calc
        ASE calculator to use for computing forces.
    check_permutation
        Whether :func:`find_permutation` should be used or not.

    Returns
    -------
    List of prepared structures with forces and displacements as arrays.
    """
    return [
        prepare_structure(structure,
                          atoms_ideal,
                          calc=calc,
                          check_permutation=check_permutation)
        for structure in structures
    ]
def find_permutation(atoms: Atoms, atoms_ref: Atoms) -> List[int]:
    """ Returns the best permutation of atoms for mapping one
    configuration onto another.

    For every site in :attr:`atoms_ref` the closest atom in :attr:`atoms`
    (using periodic boundary conditions in all directions and the cell of
    :attr:`atoms_ref`) is selected.

    Parameters
    ----------
    atoms
        configuration to be permuted
    atoms_ref
        configuration onto which to map

    Raises
    ------
    AssertionError
        If the cells of the two configurations differ.
    ValueError
        If the same atom is the closest match for more than one reference
        site, or if matched sites are occupied by different species.

    Examples
    --------
    After obtaining the permutation via ``p = find_permutation(atoms1, atoms2)``
    the reordered structure ``atoms1[p]`` will give the closest match
    to ``atoms2``.
    """
    # Raised explicitly rather than via an ``assert`` statement so that the
    # check is not stripped when Python runs with -O; the exception type is
    # kept for backward compatibility.
    if not np.linalg.norm(atoms.cell - atoms_ref.cell) < 1e-6:
        raise AssertionError('Cells of the two configurations do not match.')

    # for each reference site pick the index of the closest atom
    permutation = []
    for ref_position in atoms_ref.positions:
        dist_row = get_distances(
            atoms.positions, ref_position, cell=atoms_ref.cell, pbc=True)[1][:, 0]
        permutation.append(np.argmin(dist_row))

    # every atom may be assigned to one reference site at most
    if len(set(permutation)) != len(permutation):
        raise ValueError('Duplicates in permutation')

    # comparing atomic numbers is equivalent to comparing chemical symbols
    # and avoids creating an Atom object per site
    if np.any(atoms.numbers[permutation] != atoms_ref.numbers):
        raise ValueError('Matching lattice sites have different occupation')
    return permutation
class Shell:
    """
    Neighbor Shell class

    Parameters
    ----------
    types : Union[list, tuple]
        Atomic types for neighbor shell.
    distance : float
        Interatomic distance for neighbor shell.
    count : int
        Number of pairs in the neighbor shell.
    """

    def __init__(self,
                 types: List[str],
                 distance: float,
                 count: int = 0):
        self.types, self.distance, self.count = types, distance, count

    def __repr__(self):
        # the two entries of `types` fill the leading "{}-{}" placeholders
        template = '{}-{} distance: {:10.6f} count: {}'
        return template.format(*self.types, self.distance, self.count)

    __str__ = __repr__
def get_neighbor_shells(atoms: Atoms,
                        cutoff: float,
                        dist_tol: float = 1e-5) -> List[Shell]:
    """ Returns a list of neighbor shells.

    Distances are grouped into shells via the following algorithm:

    1. All neighbor-list entries within the cutoff are sorted by distance.

    2. An entry is added to the first existing shell that has the same
       (sorted) pair of atomic types and a distance that differs by less
       than `dist_tol`.

    3. If no such shell exists, a new shell is created at this distance.

    A warning is logged for every pair of shells with identical types that
    are not separated by more than `2 * dist_tol`.

    Parameters
    ----------
    atoms
        Configuration used for finding shells.
    cutoff
        Exclude neighbor shells which have a distance larger than this value.
    dist_tol
        Distance tolerance.

    Returns
    -------
    Shells sorted by distance, types, and count.
    """

    # get neighbor-list entries (i, j, d), ordered by increasing distance
    pairs = sorted(zip(*neighbor_list('ijd', atoms, cutoff)), key=lambda entry: entry[2])

    # sort into shells
    species = atoms.get_chemical_symbols()
    shells = []
    for i, j, d in pairs:
        pair_types = tuple(sorted((species[i], species[j])))
        matching = next(
            (s for s in shells if abs(d - s.distance) < dist_tol and pair_types == s.types),
            None)
        if matching is None:
            shells.append(Shell(pair_types, d, 1))
        else:
            matching.count += 1
    shells.sort(key=lambda s: (s.distance, s.types, s.count))

    # warning if two shells are within 2 * tol
    for index, first in enumerate(shells):
        for second in shells[index + 1:]:
            same_types = first.types == second.types
            if same_types and not first.distance < second.distance - 2 * dist_tol:
                logger.warning('Found two shells within 2 * dist_tol')

    return shells
def extract_parameters(fcs: ForceConstants,
                       cs: ClusterSpace,
                       sanity_check: bool = True,
                       lstsq_method: str = 'numpy') -> np.ndarray:
    """ Extracts parameters from force constants.

    This function can be used to extract parameters to create a
    ForceConstantPotential from a known set of force constants.
    The parameters are the solution vector obtained from NumPy's `lstsq function
    <https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html>`_
    or from SciPy's `sparse lsqr function
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html>`_.
    Using `lstsq_method='scipy'` might be faster and have a smaller memory footprint for large
    systems, at the expense of some accuracy. This is due to the use of sparse matrices
    and an iterative solver.

    Parameters
    ----------
    fcs
        Force constants.
    cs
        Cluster space.
    sanity_check
        Bool whether or not to perform a sanity check by computing the relative error between
        the input fcs and the output FCs. The error is printed for each order.
    lstsq_method
        Method to use when making a least-squares fit of a :class:`ForceConstantModel` to the
        given FCs, allowed values are `'numpy'` for :func:`np.linalg.lstsq`
        and `'scipy'` for :func:`scipy.sparse.linalg.lsqr`.

    Returns
    -------
    Parameters that together with the ClusterSpace give the best representation of the FCs.

    Raises
    ------
    ValueError
        If :attr:`lstsq_method` is neither `'numpy'` nor `'scipy'`.
    """
    from .force_constant_model import ForceConstantModel
    from .force_constant_potential import ForceConstantPotential

    if lstsq_method not in ('numpy', 'scipy'):
        raise ValueError('lstsq_method must be either numpy or scipy')

    # extract the parameters
    fcm = ForceConstantModel(fcs.supercell, cs)
    # If the cluster space is large, a sparse least-squares solver is faster
    if lstsq_method == 'numpy':
        A, b = fcm.get_fcs_sensing(fcs, sparse=False)
        parameters = np.linalg.lstsq(A, b, rcond=None)[0]
    else:
        # imported here since it is only needed for the sparse solver
        from scipy.sparse.linalg import lsqr
        A, b = fcm.get_fcs_sensing(fcs, sparse=True)
        # set minimal tolerances to maximize iterative least squares accuracy
        parameters = lsqr(A, b, atol=0, btol=0, conlim=0)[0]

    # calculate the relative force constant error
    if sanity_check:
        fcp = ForceConstantPotential(cs, parameters)
        fcs_hiphive = fcp.get_force_constants(fcs.supercell)
        for order in cs.cutoffs.orders:
            fc_original = fcs.get_fc_array(order=order)
            fc_reconstructed = fcs_hiphive.get_fc_array(order=order)
            rel_error = np.linalg.norm(fc_original-fc_reconstructed) / np.linalg.norm(fc_original)
            print(f'Force constant reconstruction error order {order}: {100*rel_error:9.4f}%')

    return parameters