Coverage for hiphive/core/structure_alignment.py: 93%

114 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-01 17:04 +0000

1import itertools 

2from typing import Union 

3import ase 

4import numpy as np 

5import spglib as spg 

6from . import atoms as atoms_module 

7from ..input_output.logging_tools import logger 

8from .utilities import ase_atoms_to_spglib_tuple 

9 

# Module-level child logger used by all alignment routines in this file.
logger = logger.getChild('relate_structures')

11 

12 

def align_supercell(
    supercell: ase.Atoms,
    prim: ase.Atoms,
    symprec: Union[float, None] = None,
) -> tuple[ase.Atoms, np.ndarray, np.ndarray]:
    """Rotates and translates a supercell configuration such that it is aligned
    with the target primitive cell.

    Parameters
    ----------
    supercell
        Supercell configuration.
    prim
        Target primitive configuration.
    symprec
        Precision parameter forwarded to spglib. If ``None`` (default), the
        default precision of :func:`relate_structures` (``1e-5``) is used.

    Returns
    -------
    Aligned supercell configuration as well as rotation matrix
    (`3x3` array) and translation vector (length-3 array) that relate
    the input to the aligned supercell configuration.
    """

    # TODO: Make sure the input is what we expect

    # ``None`` used to be forwarded verbatim all the way to
    # ``spglib.standardize_cell``, which requires a float; fall back to the
    # default precision used by ``relate_structures`` instead.
    if symprec is None:
        symprec = 1e-5

    # find rotation and translation
    R, T = relate_structures(supercell, prim, symprec=symprec)

    # Create the aligned system: rotate, shift, then map back into the cell
    aligned_supercell = rotate_atoms(supercell, R)
    aligned_supercell.translate(T)
    aligned_supercell.wrap()
    return aligned_supercell, R, T

47 

48 

def relate_structures(
    reference: ase.Atoms,
    target: ase.Atoms,
    symprec: float = 1e-5,
) -> tuple[np.ndarray, np.ndarray]:
    """Finds rotation and translation operations that align two structures with
    periodic boundary conditions.

    Both structures are first reduced to primitive cells via spglib. The
    returned rotation and translation, in Cartesian coordinates, map the
    reference structure onto the target::

        R, T = relate_structures(atoms_ref, atoms_target)
        atoms_ref_rotated = rotate_atoms(atoms_ref, R)
        atoms_ref_rotated.translate(T)
        atoms_ref_rotated.wrap()
        atoms_ref_rotated == atoms_target

    Parameters
    ----------
    reference
        Reference structure to be mapped.
    target
        Target structure.
    symprec
        Precision parameter forwarded to spglib when reducing both
        structures to primitive cells.

    Returns
    -------
    A tuple comprising the rotation matrix in Cartesian coordinates (`3x3` array)
    and the translation vector in Cartesian coordinates.

    Raises
    ------
    ValueError
        If the primitive cells differ in number of atoms, atomic numbers or
        volume.
    Exception
        If none of the candidate rotations admits a matching translation.
    """

    def _reduce(atoms, header, primitive_header):
        # Log the input, reduce it to its primitive cell, log the result.
        logger.debug(header)
        _debug_log_atoms(atoms)
        primitive = get_primitive_cell(atoms, symprec=symprec)
        logger.debug(primitive_header)
        _debug_log_atoms(primitive)
        return primitive

    ref_prim = _reduce(reference, 'Reference atoms:',
                       'Reference primitive cell')
    tgt_prim = _reduce(target, 'Target atoms:', 'Target primitive cell')

    logger.debug('Sane check that primitive cells can match...')
    _assert_structures_match(ref_prim, tgt_prim)

    logger.debug('Finding rotations...')
    candidate_rotations = _find_rotations(ref_prim.cell, tgt_prim.cell)

    logger.debug('Finding transformations...')
    solution = None
    for rotation in candidate_rotations:
        translation = _find_translation(rotate_atoms(ref_prim, rotation),
                                        tgt_prim)
        if translation is not None:
            solution = (rotation, translation)
            break

    if solution is None:
        raise Exception(('Found no translation!\n'
                         'Reference primitive cell basis:\n'
                         '{}\n'
                         'Target primitive cell basis:\n'
                         '{}')
                        .format(ref_prim.basis, tgt_prim.basis))

    R, T = solution
    logger.debug(('Found rotation\n'
                  '{}\n'
                  'and translation\n'
                  '{}')
                 .format(R, T))

    return R, T

128 

129 

def is_rotation(R: np.ndarray, cell_metric: np.ndarray = None,
                atol: float = 1e-4) -> bool:
    """Checks if a rotation matrix is orthonormal.

    A cell metric can be passed if the rotation matrix is in scaled
    coordinates.

    Parameters
    ----------
    R
        Rotation matrix (`3x3` array).
    cell_metric
        Cell metric if the rotation is in scaled coordinates. If ``None``
        (default) the identity is used, i.e. `R` is treated as Cartesian.
    atol
        Absolute tolerance used when comparing against the identity.

    Returns
    -------
    ``True`` if ``V^-1 R^T V V^T R V^-T`` equals the identity within `atol`,
    ``False`` otherwise.
    """
    # ``not cell_metric`` raises "truth value of an array is ambiguous" for
    # any numpy array, so compare against None explicitly.
    if cell_metric is None:
        cell_metric = np.eye(3)

    R = np.asarray(R)
    V = np.asarray(cell_metric)
    V_inv = np.linalg.inv(V)
    lhs = np.linalg.multi_dot([V_inv, R.T, V, V.T, R, V_inv.T])

    return bool(np.allclose(lhs, np.eye(3), atol=atol))

149 

150 

def _find_rotations(reference_cell_metric, target_cell_metric):
    """Generates all proper and improper rotations aligning two cell metrics.

    Candidates are built by permuting the rows of the target cell metric and
    flipping their signs. Each candidate ``R`` solves ``V1 @ R.T == V2``. It
    is kept if :func:`is_rotation` accepts it and it is not already in the
    result.

    Parameters
    ----------
    reference_cell_metric
        Reference cell metric (`3x3`, rows treated as cell vectors).
    target_cell_metric
        Target cell metric (`3x3`, rows treated as cell vectors).

    Returns
    -------
    list of unique `3x3` rotation matrices.

    Raises
    ------
    AssertionError
        If no rotation aligns the two cell metrics.
    """
    V1 = np.asarray(reference_cell_metric)
    target = np.asarray(target_cell_metric)

    rotations = []
    for perm in itertools.permutations([0, 1, 2]):
        # Sign flips make sure the improper rotations are included
        for signs in itertools.product([1, -1], repeat=3):
            V2 = np.diag(signs) @ target[list(perm), :]
            R = np.linalg.solve(V1, V2).T
            # Make sure the rotation is orthonormal
            if not is_rotation(R):
                continue
            # Skip duplicates  # TODO: tol
            if any(np.allclose(R, R_known) for R_known in rotations):
                continue
            rotations.append(R)

    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``; AssertionError is kept for backward compatibility.
    if not rotations:
        raise AssertionError(('Found no rotations! Reference cell metric:\n'
                              '{}\n'
                              'Target cell metric:\n'
                              '{}').format(reference_cell_metric,
                                           target_cell_metric))

    logger.debug('Found {} rotations'.format(len(rotations)))

    return rotations

178 

179 

def _assert_structures_match(ref, prim):
    """Raises ``ValueError`` unless two structures are compatible.

    Compatibility requires, checked in this order: the same number of atoms,
    the same multiset of atomic numbers, and volumes that agree according to
    ``np.isclose``.

    TODO: tol
    """
    n_ref, n_prim = len(ref), len(prim)
    if n_ref != n_prim:
        raise ValueError(
            'Number of atoms in reference primitive cell {} does not match '
            'target primitive {}'.format(n_ref, n_prim))

    same_species = sorted(ref.numbers) == sorted(prim.numbers)
    if not same_species:
        raise ValueError(
            'Atomic numbers do not match\nReference: {}\nTarget: {}\n'
            .format(ref.numbers, prim.numbers))

    vol_ref, vol_prim = ref.get_volume(), prim.get_volume()
    if not np.isclose(vol_ref, vol_prim):
        raise ValueError(
            'Volume for reference primitive cell {} does not match target '
            'primitive cell {}\n'.format(vol_ref, vol_prim))

200 

201 

def get_primitive_cell(
    atoms: ase.Atoms,
    to_primitive: bool = True,
    no_idealize: bool = True,
    symprec: float = 1e-5,
) -> atoms_module.Atoms:
    """ Returns primitive cell obtained using spglib.

    Parameters
    ----------
    atoms
        Atomic structure; must be periodic along all three directions.
    to_primitive
        Passed to spglib.
    no_idealize
        Passed to spglib.
    symprec
        Numerical tolerance; passed to spglib.

    Returns
    -------
    The standardized cell returned by spglib as a periodic
    ``atoms_module.Atoms`` object.

    Raises
    ------
    ValueError
        If `atoms` is not periodic in all directions, or if spglib fails to
        standardize the cell.
    """
    if not all(atoms.pbc):
        raise ValueError('atoms must have pbc.')
    atoms_as_tuple = ase_atoms_to_spglib_tuple(atoms)
    # Forward the caller's flags; they used to be ignored in favour of
    # hard-coded ``True`` values (which remain the defaults here).
    spg_primitive_cell = spg.standardize_cell(atoms_as_tuple,
                                              to_primitive=to_primitive,
                                              no_idealize=no_idealize,
                                              symprec=symprec)
    # spglib signals failure by returning None (at least in some versions);
    # fail with a clear message rather than a TypeError on indexing below.
    if spg_primitive_cell is None:
        raise ValueError('spglib failed to standardize the cell '
                         '(symprec={})'.format(symprec))
    primitive_cell = atoms_module.Atoms(cell=spg_primitive_cell[0],
                                        scaled_positions=spg_primitive_cell[1],
                                        numbers=spg_primitive_cell[2],
                                        pbc=True)
    return primitive_cell

232 

233 

def _debug_log_atoms(atoms):
    """Logs, at debug level and as four separate messages, the cell, scaled
    positions, Cartesian positions and atomic numbers of `atoms`."""
    entries = (
        ('cell', atoms.cell),
        ('spos', atoms.get_scaled_positions()),
        ('pos', atoms.positions),
        ('numbers', atoms.numbers),
    )
    for label, value in entries:
        logger.debug('{}:\n{}'.format(label, value))

239 

240 

def rotate_atoms(
    atoms: Union[ase.Atoms, atoms_module.Atoms],
    rotation: np.ndarray,
) -> atoms_module.Atoms:
    """Returns a rotated copy of `atoms`.

    Every row of the cell and every Cartesian position is multiplied by
    `rotation`. The input object itself is not modified.

    Parameters
    ----------
    atoms
        Atomic structure.
    rotation
        Rotation matrix (`3x3` array).

    Returns
    -------
    New ``atoms_module.Atoms`` with the rotated cell and positions and the
    original atomic numbers and periodic boundary conditions.
    """
    def _apply(row_vectors):
        # rotate each row vector v as (rotation @ v)
        return np.dot(rotation, row_vectors.T).T

    return atoms_module.Atoms(
        cell=_apply(atoms.cell),
        positions=_apply(atoms.positions),
        numbers=atoms.numbers,
        pbc=atoms.pbc,
    )

258 

259 

def _find_translation(reference, target):
    """Returns the translation between two compatible atomic structures.

    The two structures must describe the same structure when infinitely
    repeated but differ by a translation.

    Parameters
    ----------
    reference : ase.Atoms
    target : ase.Atoms

    Returns
    -------
    numpy.ndarray or None
        translation vector or `None` if structures are incompatible
    """
    # Express the reference positions inside the target cell and wrap them,
    # so both structures live in the same periodic frame.
    shifted = atoms_module.Atoms(cell=target.cell,
                                 positions=reference.positions,
                                 numbers=reference.numbers,
                                 pbc=True)
    shifted.wrap()

    anchor = shifted[0]
    # Each same-element target atom is a candidate image of the anchor; the
    # first candidate under which the structures coincide gives the shift.
    candidates = (atom for atom in target if atom.symbol == anchor.symbol)
    for candidate in candidates:
        T = candidate.position - anchor.position
        trial = shifted.copy()
        trial.positions += T
        if are_nonpaired_configurations_equal(trial, target):
            return T
    return None

293 

294 

def are_nonpaired_configurations_equal(
    atoms1: 'ase.Atoms',
    atoms2: 'ase.Atoms',
    tol: float = 1e-4,
) -> bool:
    """ Checks whether two configurations are identical. To be considered
    equal the structures must have the same cell metric, elemental
    occupation, scaled positions (modulo one), and periodic boundary
    conditions.

    Unlike the ``__eq__`` operator of :class:`ase.Atoms` the order of the
    atoms does not matter. Each atom of `atoms1` must coincide with a
    *distinct* atom of `atoms2` of the same atomic number, i.e. a one-to-one
    pairing is required.

    Parameters
    ----------
    atoms1
    atoms2
    tol
        Absolute tolerance used both for comparing the cells and for deciding
        whether two periodic images coincide (Cartesian distance).

    Returns
    -------
    True if atoms are equal, False otherwise

    """
    n_atoms = len(atoms1)
    if not (np.allclose(atoms1.cell, atoms2.cell, atol=tol) and
            n_atoms == len(atoms2) and
            sorted(atoms1.numbers) == sorted(atoms2.numbers) and
            all(atoms1.pbc == atoms2.pbc)):
        return False
    if n_atoms == 0:
        return True

    # The cells agree within ``tol``; compare positions in their average.
    cell = (np.asarray(atoms1.cell, dtype=float) +
            np.asarray(atoms2.cell, dtype=float)) / 2
    inv_cell = np.linalg.inv(cell)
    pos1 = np.asarray(atoms1.positions, dtype=float)
    pos2 = np.asarray(atoms2.positions, dtype=float)
    num1 = np.asarray(atoms1.numbers)
    num2 = np.asarray(atoms2.numbers)

    # Atoms of atoms2 that have not been paired yet (enforces one-to-one).
    available = np.ones(n_atoms, dtype=bool)
    for i in range(n_atoms):
        candidates = np.flatnonzero(available & (num2 == num1[i]))
        # Displacements in scaled coordinates, folded into [-0.5, 0.5] along
        # every direction (periodic in all directions, as before); enough to
        # detect sites that coincide modulo a lattice vector.
        delta = (pos2[candidates] - pos1[i]) @ inv_cell
        delta -= np.round(delta)
        distances = np.linalg.norm(delta @ cell, axis=1)
        hits = candidates[distances < tol]
        if hits.size == 0:
            return False
        available[hits[0]] = False
    return True