Coverage for hiphive/core/structure_alignment.py: 93%
114 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
1import itertools
2from typing import Union
3import ase
4import numpy as np
5import spglib as spg
6from . import atoms as atoms_module
7from ..input_output.logging_tools import logger
8from .utilities import ase_atoms_to_spglib_tuple
# Module logger: a child (named 'relate_structures') of the package logger
# imported above. All debug output of the functions below goes through it.
logger = logger.getChild('relate_structures')
def align_supercell(
    supercell: ase.Atoms,
    prim: ase.Atoms,
    symprec: float = None,
) -> tuple[ase.Atoms, np.ndarray, np.ndarray]:
    """Rotates and translates a supercell configuration such that it is aligned
    with the target primitive cell.

    Parameters
    ----------
    supercell
        Supercell configuration.
    prim
        Target primitive configuration.
    symprec
        Precision parameter forwarded to spglib. If ``None`` (the default),
        the default tolerance of :func:`relate_structures` is used.

    Returns
    -------
    Aligned supercell configuration as well as rotation matrix
    (`3x3` array) and translation vector (length-3 array) that relate
    the input to the aligned supercell configuration.
    """
    # Only forward symprec when it was given: relate_structures and
    # get_primitive_cell declare it as a float and hand it unchanged to
    # spglib, so passing ``None`` through would reach spglib as-is.
    kwargs = {} if symprec is None else {'symprec': symprec}

    # find rotation and translation
    R, T = relate_structures(supercell, prim, **kwargs)

    # Create the aligned system (rotate_atoms returns a new object)
    aligned_supercell = rotate_atoms(supercell, R)
    aligned_supercell.translate(T)
    aligned_supercell.wrap()
    return aligned_supercell, R, T
def relate_structures(
    reference: ase.Atoms,
    target: ase.Atoms,
    symprec: float = 1e-5,
) -> tuple[np.ndarray, np.ndarray]:
    """Finds rotation and translation operations that align two structures with
    periodic boundary conditions.

    The rotation and translation in Cartesian coordinates will map the
    reference structure onto the target.

    Aligning the reference with the target structure can be achieved via the
    transformations (equality holding up to the order of the atoms)::

        R, T = relate_structures(atoms_ref, atoms_target)
        atoms_ref_rotated = rotate_atoms(atoms_ref, R)
        atoms_ref_rotated.translate(T)
        atoms_ref_rotated.wrap()
        atoms_ref_rotated == atoms_target

    Parameters
    ----------
    reference
        Reference structure to be mapped.
    target
        Target structure.
    symprec
        Numerical tolerance passed to spglib when both structures are
        reduced to their primitive cells.

    Returns
    -------
    A tuple comprising the rotation matrix in Cartesian coordinates (`3x3` array)
    and the translation vector in Cartesian coordinates.

    Raises
    ------
    ValueError
        If the primitive cells differ in number of atoms, atomic numbers or
        volume, or if none of the candidate rotations admits a translation
        mapping the reference primitive cell onto the target primitive cell.
    """
    logger.debug('Reference atoms:')
    _debug_log_atoms(reference)

    reference_primitive_cell = get_primitive_cell(reference, symprec=symprec)

    logger.debug('Reference primitive cell')
    _debug_log_atoms(reference_primitive_cell)

    logger.debug('Target atoms:')
    _debug_log_atoms(target)

    target_primitive_cell = get_primitive_cell(target, symprec=symprec)

    logger.debug('Target primitive cell')
    _debug_log_atoms(target_primitive_cell)

    logger.debug('Sane check that primitive cells can match...')
    _assert_structures_match(reference_primitive_cell, target_primitive_cell)

    logger.debug('Finding rotations...')
    rotations = _find_rotations(reference_primitive_cell.cell,
                                target_primitive_cell.cell)

    logger.debug('Finding transformations...')
    # Keep the first candidate rotation for which a matching translation exists.
    for R in rotations:
        rotated_reference_primitive_cell = \
            rotate_atoms(reference_primitive_cell, R)
        T = _find_translation(rotated_reference_primitive_cell,
                              target_primitive_cell)
        if T is not None:
            break
    else:
        # ValueError is a subclass of Exception, so handlers written for the
        # previous bare Exception still catch it.
        raise ValueError(('Found no translation!\n'
                          'Reference primitive cell basis:\n'
                          '{}\n'
                          'Target primitive cell basis:\n'
                          '{}')
                         .format(reference_primitive_cell.basis,
                                 target_primitive_cell.basis))

    logger.debug(('Found rotation\n'
                  '{}\n'
                  'and translation\n'
                  '{}')
                 .format(R, T))

    return R, T
def is_rotation(R: np.ndarray, cell_metric: np.ndarray = None,
                atol: float = 1e-4) -> bool:
    """Checks if a rotation matrix is orthonormal.

    A cell metric can be passed if the rotation matrix is in scaled
    coordinates.

    Parameters
    ----------
    R
        Rotation matrix (`3x3` array).
    cell_metric
        Cell metric if the rotation is in scaled coordinates. Defaults to the
        identity, i.e. `R` is treated as a Cartesian matrix.
    atol
        Absolute tolerance used when comparing the product below with the
        identity matrix.

    Returns
    -------
    True if ``V^-1 R^T V V^T R V^-T`` equals the identity within `atol`,
    where ``V`` is the cell metric; False otherwise.
    """
    # Compare against None explicitly: ``not cell_metric`` raises
    # "truth value of an array is ambiguous" for any numpy array.
    if cell_metric is None:
        cell_metric = np.eye(3)

    V = np.asarray(cell_metric)
    V_inv = np.linalg.inv(V)
    lhs = np.linalg.multi_dot([V_inv, R.T, V, V.T, R, V_inv.T])

    return np.allclose(lhs, np.eye(3), atol=atol)
def _find_rotations(reference_cell_metric, target_cell_metric):
    """Collects every distinct proper and improper rotation relating two cell
    metrics.

    A candidate ``R`` solves ``reference @ R.T == S @ P @ target``, where ``P``
    permutes the rows of the target cell and ``S`` is a diagonal matrix of
    +1/-1 sign flips. Candidates that are not orthonormal are discarded, as
    are duplicates of rotations already found.

    Returns the list of `3x3` rotation matrices; an AssertionError is raised
    if the list would be empty.
    """
    found = []
    for row_order in itertools.permutations([0, 1, 2]):
        # The sign flips make sure improper rotations are included as well.
        for signs in itertools.product([1, -1], repeat=3):
            candidate_metric = np.diag(signs) @ target_cell_metric[row_order, :]
            candidate = np.linalg.solve(reference_cell_metric,
                                        candidate_metric).T
            if not is_rotation(candidate):
                continue
            # Only record rotations that are not already in the list.
            if not any(np.allclose(candidate, known) for known in found):  # TODO: tol
                found.append(candidate)

    assert found, ('Found no rotations! Reference cell metric:\n'
                   '{}\n'
                   'Target cell metric:\n'
                   '{}').format(reference_cell_metric, target_cell_metric)

    logger.debug('Found {} rotations'.format(len(found)))

    return found
def _assert_structures_match(ref, prim):
    """Raises ``ValueError`` unless the two structures are compatible.

    They are compatible when they have the same number of atoms, the same
    multiset of atomic numbers, and volumes that agree within
    ``np.isclose`` defaults. The checks run in that order and the first
    failing one determines the error message.

    TODO: tol
    """
    n_ref, n_prim = len(ref), len(prim)
    if n_ref != n_prim:
        raise ValueError(
            'Number of atoms in reference primitive cell {} does not match '
            'target primitive {}'.format(n_ref, n_prim))

    if sorted(ref.numbers) != sorted(prim.numbers):
        raise ValueError('Atomic numbers do not match\nReference: {}\nTarget:'
                         ' {}\n'.format(ref.numbers, prim.numbers))

    volume_ref, volume_prim = ref.get_volume(), prim.get_volume()
    if np.isclose(volume_ref, volume_prim):
        return
    raise ValueError(
        'Volume for reference primitive cell {} does not match target '
        'primitive cell {}\n'.format(volume_ref, volume_prim))
def get_primitive_cell(
    atoms: ase.Atoms,
    to_primitive: bool = True,
    no_idealize: bool = True,
    symprec: float = 1e-5,
) -> atoms_module.Atoms:
    """Returns primitive cell obtained using spglib.

    Parameters
    ----------
    atoms
        Atomic structure; must be periodic along all three directions.
    to_primitive
        Passed to spglib.
    no_idealize
        Passed to spglib.
    symprec
        Numerical tolerance; passed to spglib.

    Returns
    -------
    The cell returned by ``spglib.standardize_cell``, as an Atoms object
    with periodic boundary conditions.

    Raises
    ------
    ValueError
        If `atoms` is not periodic along all directions, or if spglib
        returns no result.
    """
    if not all(atoms.pbc):
        raise ValueError('atoms must have pbc.')
    atoms_as_tuple = ase_atoms_to_spglib_tuple(atoms)
    # Forward the caller's options instead of fixed values, so the
    # documented parameters actually take effect.
    spg_primitive_cell = spg.standardize_cell(atoms_as_tuple,
                                              to_primitive=to_primitive,
                                              no_idealize=no_idealize,
                                              symprec=symprec)
    # NOTE(review): some spglib versions report failure by returning None
    # instead of raising; fail with a clear message rather than a TypeError.
    if spg_primitive_cell is None:
        raise ValueError('spglib failed to standardize the cell.')
    primitive_cell = atoms_module.Atoms(cell=spg_primitive_cell[0],
                                        scaled_positions=spg_primitive_cell[1],
                                        numbers=spg_primitive_cell[2],
                                        pbc=True)
    return primitive_cell
def _debug_log_atoms(atoms):
    """Logs the cell, scaled positions, Cartesian positions and atomic numbers
    of `atoms` at debug level, one log record per quantity."""
    entries = (
        ('cell:\n{}', lambda: atoms.cell),
        ('spos:\n{}', atoms.get_scaled_positions),
        ('pos:\n{}', lambda: atoms.positions),
        ('numbers:\n{}', lambda: atoms.numbers),
    )
    for template, fetch in entries:
        logger.debug(template.format(fetch()))
def rotate_atoms(
    atoms: Union[ase.Atoms, atoms_module.Atoms],
    rotation: np.ndarray,
) -> atoms_module.Atoms:
    """Returns a rotated copy of `atoms`.

    Both the cell vectors and the Cartesian positions are rotated by
    `rotation`. Only the atomic numbers and the pbc flags are carried over to
    the new object; the input is not modified.

    Parameters
    ----------
    atoms
        Atomic structure.
    rotation
        Rotation matrix (`3x3` array).
    """
    def _rotate_rows(row_vectors):
        # Each row is one vector, so rotate via (rotation @ rows.T).T.
        return np.dot(rotation, row_vectors.T).T

    return atoms_module.Atoms(
        cell=_rotate_rows(atoms.cell),
        positions=_rotate_rows(atoms.positions),
        numbers=atoms.numbers,
        pbc=atoms.pbc,
    )
def _find_translation(reference, target):
    """Returns the translation between two compatible atomic structures.

    The two structures must describe the same structure when infinitely
    repeated but differ by a translation.

    The reference positions are first placed into the target cell and
    wrapped. Each target atom with the same chemical symbol as the first
    wrapped reference atom then defines a candidate shift. The first shift
    under which :func:`are_nonpaired_configurations_equal` reports the two
    configurations as equal is returned.

    Parameters
    ----------
    reference : ase.Atoms
    target : ase.Atoms

    Returns
    -------
    numpy.ndarray or None
        translation vector or `None` if structures are incompatible
    """
    wrapped = atoms_module.Atoms(cell=target.cell,
                                 positions=reference.positions,
                                 numbers=reference.numbers,
                                 pbc=True)
    wrapped.wrap()

    anchor = wrapped[0]
    for candidate in target:
        if candidate.symbol == anchor.symbol:
            shift = candidate.position - anchor.position
            shifted = wrapped.copy()
            shifted.positions += shift
            if are_nonpaired_configurations_equal(shifted, target):
                return shift
    return None
def are_nonpaired_configurations_equal(
    atoms1: ase.Atoms,
    atoms2: ase.Atoms,
    cell_tol: float = 1e-4,
    position_tol: float = 1e-4,
) -> bool:
    """Checks whether two configurations are identical.

    To be considered equal, the structures must satisfy all of the following:

    * their cells agree element-wise within `cell_tol`;
    * they contain the same number of atoms and the same multiset of atomic
      numbers;
    * their periodic boundary conditions are equal;
    * every atom of `atoms1` lies, under the minimum image convention,
      within `position_tol` of an atom of `atoms2` with the same atomic
      number.

    Unlike the ``__eq__`` operator of :class:`ase.Atoms` the order of the
    atoms does not matter.

    Parameters
    ----------
    atoms1
    atoms2
    cell_tol
        Absolute tolerance for comparing the two cells.
    position_tol
        Distance below which two atoms are considered to coincide.

    Returns
    -------
    True if atoms are equal, False otherwise
    """
    n_atoms = len(atoms1)
    if not (np.allclose(atoms1.cell, atoms2.cell, atol=cell_tol) and
            n_atoms == len(atoms2) and
            sorted(atoms1.numbers) == sorted(atoms2.numbers) and
            all(atoms1.pbc == atoms2.pbc)):
        return False

    # Put both configurations into one periodic cell (the mean of the two
    # nearly identical cells), so that minimum-image distances between atoms
    # of atoms1 (indices < n_atoms) and atoms2 (indices >= n_atoms) can be
    # computed.
    new_cell = (atoms1.cell + atoms2.cell) / 2
    positions = np.vstack([atoms1.positions, atoms2.positions])
    numbers = np.concatenate([atoms1.numbers, atoms2.numbers])
    combined = atoms_module.Atoms(cell=new_cell, positions=positions,
                                  numbers=numbers, pbc=True)

    second = np.arange(n_atoms, len(combined))
    for i in range(n_atoms):
        # One vectorized distance evaluation per atom instead of one
        # get_distance call per atom pair.
        distances = combined.get_distances(i, second, mic=True)
        matches = np.flatnonzero(np.abs(distances) < position_tol)
        if matches.size == 0:
            return False
        # NOTE(review): only the first coinciding atom is examined and the
        # matching is not enforced to be one-to-one.
        if numbers[i] != numbers[second[matches[0]]]:
            return False
    return True