Coverage for hiphive/utilities.py: 97%
121 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1"""
2This module contains various support/utility functions.
3"""
5from typing import List, Tuple
6import numpy as np
8from ase import Atoms
9from ase.calculators.singlepoint import SinglePointCalculator
10from ase.geometry import find_mic
11from ase.geometry import get_distances
12from ase.neighborlist import neighbor_list
13from .cluster_space import ClusterSpace
14from .force_constants import ForceConstants
15from .input_output.logging_tools import logger
18logger = logger.getChild('utilities')
def get_displacements(atoms: Atoms,
                      atoms_ideal: Atoms,
                      cell_tol: float = 1e-4) -> np.ndarray:
    """Returns the smallest possible displacements of a displaced
    configuration relative to an ideal (reference) configuration.

    Notes
    -----
    * uses :func:`ase.geometry.find_mic`
    * assumes periodic boundary conditions in all directions

    Parameters
    ----------
    atoms
        configuration with displaced atoms
    atoms_ideal
        ideal configuration relative to which displacements are computed
    cell_tol
        cell tolerance; an error is raised if the norm of the difference
        between the two cells exceeds this value

    Raises
    ------
    ValueError
        if the atomic numbers of the two configurations are not equal
    ValueError
        if the cells differ by more than `cell_tol`
    """
    same_species = np.array_equal(atoms.numbers, atoms_ideal.numbers)
    if not same_species:
        raise ValueError('Atomic numbers do not match.')

    cell_mismatch = np.linalg.norm(atoms.cell - atoms_ideal.cell)
    if cell_mismatch > cell_tol:
        raise ValueError('Cells do not match.')

    # only the first element of the find_mic result is needed here
    position_diff = atoms.positions - atoms_ideal.positions
    mic_result = find_mic(position_diff, atoms_ideal.cell, pbc=True)
    return mic_result[0]
def _get_forces_from_atoms(atoms: Atoms, calc=None) -> np.ndarray:
    """Try to get forces from an atoms object.

    Forces are taken from exactly one of three sources, checked in this
    order: the calculator passed as `calc`, the calculator attached to
    `atoms`, or the ``'forces'`` entry of ``atoms.arrays``.

    Parameters
    ----------
    atoms
        structure for which forces are requested
    calc
        calculator used for computing the forces; must not be combined
        with a calculator that is already attached to `atoms`

    Raises
    ------
    ValueError
        if both an attached and a provided calculator are present, if the
        attached calculator is not a SinglePointCalculator, if forces
        obtained from a calculator disagree with the forces stored in
        ``atoms.arrays``, or if no forces can be found at all
    """
    has_attached_calc = atoms.calc is not None
    has_provided_calc = calc is not None

    # two competing calculators are ambiguous
    if has_attached_calc and has_provided_calc:
        raise ValueError('Atoms.calc is not None and calculator was provided')

    # no calculator at all: the forces have to be stored as an array
    if not has_provided_calc and not has_attached_calc:
        if 'forces' not in atoms.arrays:
            raise ValueError('Unable to find forces')
        return atoms.get_array('forces')

    if has_provided_calc:
        # evaluate on a copy so that the input atoms object is left untouched
        scratch = atoms.copy()
        scratch.calc = calc
        computed = scratch.get_forces()
    else:
        if not isinstance(atoms.calc, SinglePointCalculator):
            raise ValueError('atoms.calc is not a SinglePointCalculator')
        computed = atoms.get_forces()

    # forces stored as an array must agree with the calculator forces
    if 'forces' in atoms.arrays:
        stored = atoms.get_array('forces')
        if not np.allclose(computed, stored):
            raise ValueError('Forces in atoms.arrays are different from the calculator forces')
    return computed
def prepare_structure(atoms: Atoms,
                      atoms_ideal: Atoms,
                      calc: SinglePointCalculator = None,
                      check_permutation: bool = True) -> Atoms:
    """Prepare a structure in the format suitable for a
    :class:`StructureContainer <hiphive.StructureContainer>`.

    Either forces should be attached to input atoms object as an array,
    or the atoms object should have a SinglePointCalculator attached to it
    containing forces, or a calculator (calc) should be supplied.

    Parameters
    ----------
    atoms
        input structure
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces
    check_permutation
        whether find_permutation should be used or not

    Returns
    -------
    ASE atoms object
        prepared ASE atoms object with forces and displacements as arrays
    """
    forces = _get_forces_from_atoms(atoms, calc=calc)

    # atom ordering: either matched against the reference or kept as is
    if check_permutation:
        order = find_permutation(atoms, atoms_ideal)
    else:
        order = np.array(list(range(len(atoms))))

    prepared = atoms.copy()[order]
    prepared.arrays['forces'] = forces[order]
    # displacements must be computed before the positions are reset below
    prepared.arrays['displacements'] = get_displacements(prepared, atoms_ideal)
    prepared.positions = atoms_ideal.positions

    return prepared
def prepare_structures(structures: List[Atoms],
                       atoms_ideal: Atoms,
                       calc: SinglePointCalculator = None,
                       check_permutation: bool = True) -> List[Atoms]:
    """Prepares a set of structures in the format suitable for adding them to
    a :class:`StructureContainer <hiphive.StructureContainer>`.

    `structures` should represent a list of supercells with displacements
    while `atoms_ideal` should provide the ideal reference structure (without
    displacements) for the given structures.

    The structures that are returned will have their positions reset to the
    ideal structures. Displacements and forces will be added as arrays to the
    atoms objects.

    If no calculator is provided, then there must be an ASE
    `SinglePointCalculator <ase.calculators.singlepoint>` object attached to
    the structures or the forces should already be attached as
    arrays to the structures.

    If a calculator is provided then it will be used to compute the forces for
    all structures.

    Example
    -------

    The following example illustrates the use of this function::

        db = connect('dft_training_structures.db')
        training_structures = [row.toatoms() for row in db.select()]
        training_structures = prepare_structures(training_structures, atoms_ideal)
        for s in training_structures:
            sc.add_structure(s)

    Parameters
    ----------
    structures
        list of input displaced structures
    atoms_ideal
        reference structure relative to which displacements are computed
    calc
        ASE calculator used for computing forces
    check_permutation
        forwarded unchanged to :func:`prepare_structure` for every structure

    Returns
    -------
    list of prepared structures with forces and displacements as arrays
    """
    return [
        prepare_structure(structure, atoms_ideal, calc=calc, check_permutation=check_permutation)
        for structure in structures
    ]
def find_permutation(atoms: Atoms, atoms_ref: Atoms) -> List[int]:
    """ Returns the best permutation of atoms for mapping one
    configuration onto another.

    For every site of `atoms_ref` the atom of `atoms` with the smallest
    distance (evaluated with periodic boundary conditions in the cell of
    `atoms_ref`) is selected.

    Parameters
    ----------
    atoms
        configuration to be permuted
    atoms_ref
        configuration onto which to map

    Returns
    -------
    list of indices into `atoms`, one entry per site of `atoms_ref`

    Raises
    ------
    AssertionError
        if the cells of the two configurations differ
    ValueError
        if the same atom is assigned to more than one reference site
    ValueError
        if an atom and its matching reference site have different symbols

    Examples
    --------
    After obtaining the permutation via ``p = find_permutation(atoms1, atoms2)``
    the reordered structure ``atoms1[p]`` will give the closest match
    to ``atoms2``.
    """
    # An explicit raise (instead of an ``assert`` statement) keeps this check
    # active when Python runs with -O; the exception type is unchanged.
    if not np.linalg.norm(atoms.cell - atoms_ref.cell) < 1e-6:
        raise AssertionError('Cells do not match')

    permutation = []
    for ref_position in atoms_ref.positions:
        dist_row = get_distances(
            atoms.positions, ref_position, cell=atoms_ref.cell, pbc=True)[1][:, 0]
        # cast to a plain int so that the result matches the List[int] annotation
        permutation.append(int(np.argmin(dist_row)))

    if len(set(permutation)) != len(permutation):
        raise ValueError('Duplicates in permutation')

    # compare symbols via plain lists; avoids creating one Atom object per site
    symbols = atoms.get_chemical_symbols()
    ref_symbols = atoms_ref.get_chemical_symbols()
    for ref_symbol, index in zip(ref_symbols, permutation):
        if symbols[index] != ref_symbol:
            raise ValueError('Matching lattice sites have different occupation')
    return permutation
class Shell:
    """
    Neighbor Shell class

    Parameters
    ----------
    types : list or tuple
        atomic types for neighbor shell
    distance : float
        interatomic distance for neighbor shell
    count : int
        number of pairs in the neighbor shell
    """

    def __init__(self,
                 types: List[str],
                 distance: float,
                 count: int = 0):
        self.types = types
        self.distance = distance
        self.count = count

    def __str__(self):
        # Join the types instead of using a fixed '{}-{}' template so that
        # str()/repr() do not raise when `types` does not hold exactly two
        # entries; the output for two types is unchanged.
        label = '-'.join(str(t) for t in self.types)
        return '{} distance: {:10.6f} count: {}'.format(label, self.distance, self.count)

    __repr__ = __str__
def get_neighbor_shells(atoms: Atoms,
                        cutoff: float,
                        dist_tol: float = 1e-5) -> List[Shell]:
    """ Returns a list of neighbor shells.

    Distances are grouped into shells via the following algorithm:

    1. Collect all neighbor entries within `cutoff` and sort them by distance

    2. Add each entry to the first existing shell that has the same pair of
       atomic types and a distance that differs by less than `dist_tol`

    3. If no such shell exists, start a new shell with that distance

    4. Sort the shells by distance, types and count

    A warning is logged for every two shells with identical types whose
    distances are not separated by more than `2 * dist_tol`.

    Parameters
    ----------
    atoms
        configuration used for finding shells
    cutoff
        exclude neighbor shells which have a distance larger than this value
    dist_tol
        distance tolerance
    """

    # all (i, j, distance) entries within the cutoff, shortest distance first
    neighbor_entries = sorted(zip(*neighbor_list('ijd', atoms, cutoff)),
                              key=lambda entry: entry[2])

    # sort into shells
    symbols = atoms.get_chemical_symbols()
    shells = []
    for first, second, dist in neighbor_entries:
        pair_types = tuple(sorted([symbols[first], symbols[second]]))
        matching_shell = next(
            (candidate for candidate in shells
             if abs(dist - candidate.distance) < dist_tol and pair_types == candidate.types),
            None)
        if matching_shell is None:
            shells.append(Shell(pair_types, dist, 1))
        else:
            matching_shell.count += 1
    shells.sort(key=lambda s: (s.distance, s.types, s.count))

    # warning if two shells are within 2 * tol
    for index, lower in enumerate(shells):
        for upper in shells[index + 1:]:
            if lower.types == upper.types and not lower.distance < upper.distance - 2 * dist_tol:
                logger.warning('Found two shells within 2 * dist_tol')

    return shells
def extract_parameters(fcs: ForceConstants,
                       cs: ClusterSpace,
                       sanity_check: bool = True,
                       lstsq_method: str = 'numpy') -> np.ndarray:
    """ Extracts parameters from force constants.

    This function can be used to extract parameters to create a
    ForceConstantPotential from a known set of force constants.
    The parameters are the least-squares solution obtained from NumPy's
    `lstsq function
    <https://numpy.org/doc/stable/reference/generated/numpy.linalg.lstsq.html>`_
    or from SciPy's `sparse lsqr function
    <https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.lsqr.html>`_.
    Using `lstsq_method='scipy'` might be faster and have a smaller memory footprint for large
    systems, at the expense of some accuracy. This is due to the use of sparse matrices
    and an iterative solver.

    Parameters
    ----------
    fcs
        force constants
    cs
        cluster space
    sanity_check
        bool whether or not to perform a sanity check by computing the relative error between
        the input fcs and the output fcs; the error is printed for each order
    lstsq_method
        method to use when making a least squares fit of a ForceConstantModel to the given fcs,
        allowed values are 'numpy' for `np.linalg.lstsq` or 'scipy' for `scipy.sparse.linalg.lsqr`

    Returns
    -------
    parameters
        parameters that together with the ClusterSpace generates the best representation of the FCs

    Raises
    ------
    ValueError
        if `lstsq_method` is neither 'numpy' nor 'scipy'
    """
    from .force_constant_model import ForceConstantModel
    from .force_constant_potential import ForceConstantPotential

    if lstsq_method not in ('numpy', 'scipy'):
        raise ValueError('lstsq_method must be either numpy or scipy')

    # extract the parameters
    fcm = ForceConstantModel(fcs.supercell, cs)
    # If the cluster space is large, a sparse least squares solver is faster
    if lstsq_method == 'numpy':
        A, b = fcm.get_fcs_sensing(fcs, sparse=False)
        parameters = np.linalg.lstsq(A, b, rcond=None)[0]
    else:
        # imported here so that SciPy is only needed when it is actually used
        from scipy.sparse.linalg import lsqr
        A, b = fcm.get_fcs_sensing(fcs, sparse=True)
        # set minimal tolerances to maximize iterative least squares accuracy
        parameters = lsqr(A, b, atol=0, btol=0, conlim=0)[0]

    # calculate the relative force constant error
    if sanity_check:
        fcp = ForceConstantPotential(cs, parameters)
        fcs_hiphive = fcp.get_force_constants(fcs.supercell)
        for order in cs.cutoffs.orders:
            fc_original = fcs.get_fc_array(order=order)
            fc_reconstructed = fcs_hiphive.get_fc_array(order=order)
            rel_error = np.linalg.norm(fc_original-fc_reconstructed) / np.linalg.norm(fc_original)
            print(f'Force constant reconstruction error order {order}: {100*rel_error:9.4f}%')

    return parameters