Coverage for hiphive/core/structure_alignment.py: 93%

112 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

1import itertools 

2import numpy as np 

3import spglib as spg 

4from . import atoms as atoms_module 

5from ..input_output.logging_tools import logger 

6from .utilities import ase_atoms_to_spglib_tuple 

7 

# Module logger: a child of the package logger imported from
# ..input_output.logging_tools. NOTE(review): the child name
# 'relate_structures' differs from this module's name -- presumably
# historical; confirm before relying on it in logging configuration.
logger = logger.getChild('relate_structures')

9 

10 

def align_supercell(supercell, prim, symprec=None):
    """Rotate and translate a supercell configuration such that it is aligned
    with the target primitive cell.

    Parameters
    ----------
    supercell : ase.Atoms
        supercell configuration
    prim : ase.Atoms
        target primitive configuration
    symprec : float
        precision parameter forwarded to spglib; if ``None`` (default) the
        value ``1e-5`` is used, matching the default of
        :func:`relate_structures`

    Returns
    -------
    tuple(ase.Atoms, numpy.ndarray, numpy.ndarray)
        aligned supercell configuration as well as rotation matrix
        (`3x3` array) and translation vector (`3x1` array) that relate
        the input to the aligned supercell configuration.
    """

    # TODO: Make sure the input is what we expect

    # ``symprec`` ends up in spglib's standardize_cell, which requires a
    # float tolerance; forwarding ``None`` would fail there, so fall back to
    # the default used by relate_structures/get_primitive_cell.
    if symprec is None:
        symprec = 1e-5

    # find rotation and translation
    R, T = relate_structures(supercell, prim, symprec=symprec)

    # Create the aligned system: rotate first, then shift and fold the atoms
    # back into the (rotated) cell
    aligned_supercell = rotate_atoms(supercell, R)
    aligned_supercell.translate(T)
    aligned_supercell.wrap()
    return aligned_supercell, R, T

42 

43 

def relate_structures(reference, target, symprec=1e-5):
    """Find the rotation and translation that map one periodic structure onto
    another.

    Both structures are reduced to primitive cells with spglib and checked
    for compatibility. Each candidate rotation between the two primitive
    cells is then tried, in order, until one admits a translation mapping
    the rotated reference primitive cell onto the target primitive cell.
    The operations are given in Cartesian coordinates and map ``reference``
    onto ``target``::

        R, T = relate_structures(atoms_ref, atoms_target)
        atoms_ref_rotated = rotate_atoms(atoms_ref, R)
        atoms_ref_rotated.translate(T)
        atoms_ref_rotated.wrap()
        atoms_ref_rotated == atoms_target

    Parameters
    ----------
    reference : ase.Atoms
        The reference structure to be mapped
    target : ase.Atoms
        The target structure
    symprec : float
        precision parameter forwarded to spglib when determining the
        primitive cells of both structures

    Returns
    -------
    R : numpy.ndarray
        rotation matrix in Cartesian coordinates (`3x3` array)
    T : numpy.ndarray
        translation vector in Cartesian coordinates

    Raises
    ------
    ValueError
        if the primitive cells differ in number of atoms, atomic numbers or
        volume
    Exception
        if no candidate rotation admits a matching translation
    """
    # Log each input, reduce it to its primitive cell and log that as well;
    # reference is fully processed before target.
    primitive_cells = []
    for atoms, atoms_header, primitive_header in (
            (reference, 'Reference atoms:', 'Reference primitive cell'),
            (target, 'Target atoms:', 'Target primitive cell')):
        logger.debug(atoms_header)
        _debug_log_atoms(atoms)
        primitive = get_primitive_cell(atoms, symprec=symprec)
        logger.debug(primitive_header)
        _debug_log_atoms(primitive)
        primitive_cells.append(primitive)
    reference_primitive_cell, target_primitive_cell = primitive_cells

    logger.debug('Sane check that primitive cells can match...')
    _assert_structures_match(reference_primitive_cell, target_primitive_cell)

    logger.debug('Finding rotations...')
    rotations = _find_rotations(reference_primitive_cell.cell,
                                target_primitive_cell.cell)

    logger.debug('Finding transformations...')
    found = None
    for rotation in rotations:
        candidate = rotate_atoms(reference_primitive_cell, rotation)
        translation = _find_translation(candidate, target_primitive_cell)
        if translation is not None:
            found = (rotation, translation)
            break

    if found is None:
        # NOTE(review): ``basis`` is expected to be provided by the project's
        # Atoms class returned from get_primitive_cell -- not visible here.
        raise Exception(('Found no translation!\n'
                         'Reference primitive cell basis:\n'
                         '{}\n'
                         'Target primitive cell basis:\n'
                         '{}')
                        .format(reference_primitive_cell.basis,
                                target_primitive_cell.basis))

    R, T = found
    logger.debug(('Found rotation\n'
                  '{}\n'
                  'and translation\n'
                  '{}')
                 .format(R, T))

    return R, T

121 

122 

def is_rotation(R, cell_metric=None, atol=1e-4):
    """Checks if rotation matrix is orthonormal

    A cell metric can be passed if the rotation matrix is in scaled
    coordinates.

    Parameters
    ----------
    R : numpy.ndarray
        rotation matrix (`3x3` array)
    cell_metric : numpy.ndarray
        cell metric if the rotation is in scaled coordinates; if ``None``
        (default) the identity is used, i.e. ``R`` is taken to be in
        Cartesian coordinates
    atol : float
        absolute tolerance used when comparing against the identity
        (default ``1e-4``)

    Returns
    -------
    bool
        True if the metric-transformed ``R.T @ R`` equals the identity within
        ``atol``
    """
    # Use ``is None`` rather than truthiness: ``not array`` raises ValueError
    # for a 3x3 numpy array, so the previous check broke as soon as a cell
    # metric was actually supplied.
    if cell_metric is None:
        cell_metric = np.eye(3)

    R = np.asarray(R)
    V = np.asarray(cell_metric)
    V_inv = np.linalg.inv(V)
    lhs = np.linalg.multi_dot([V_inv, R.T, V, V.T, R, V_inv.T])

    return np.allclose(lhs, np.eye(3), atol=atol)

143 

144 

def _find_rotations(reference_cell_metric, target_cell_metric):
    """Generates all proper and improper rotations aligning two cell metrics.

    Each candidate is obtained by permuting the rows of the target cell
    metric and flipping their signs. The matrix ``R`` solving
    ``reference_cell_metric @ R.T == candidate`` is kept when
    :func:`is_rotation` accepts it and it is not already in the result.

    Parameters
    ----------
    reference_cell_metric : array-like
        reference cell (`3x3`)
    target_cell_metric : array-like
        target cell (`3x3`)

    Returns
    -------
    list(numpy.ndarray)
        unique rotation matrices (`3x3` arrays)

    Raises
    ------
    AssertionError
        if no rotation is found
    """
    V1 = np.asarray(reference_cell_metric)
    target = np.asarray(target_cell_metric)

    rotations = []
    for perm in itertools.permutations([0, 1, 2]):
        # Sign flips make sure the improper rotations are included
        for inv in itertools.product([1, -1], repeat=3):
            V2 = np.diag(inv) @ target[list(perm), :]
            R = np.linalg.solve(V1, V2).T
            # Make sure the rotation is orthonormal
            if not is_rotation(R):
                continue
            # Skip rotations already found  # TODO: tol
            if any(np.allclose(R, R_seen) for R_seen in rotations):
                continue
            rotations.append(R)

    # Explicit raise instead of ``assert`` so the check still runs under
    # ``python -O``; the exception type seen by callers is unchanged.
    if not rotations:
        raise AssertionError(('Found no rotations! Reference cell metric:\n'
                              '{}\n'
                              'Target cell metric:\n'
                              '{}').format(reference_cell_metric,
                                           target_cell_metric))

    logger.debug('Found {} rotations'.format(len(rotations)))

    return rotations

172 

173 

def _assert_structures_match(ref, prim):
    """Raise ``ValueError`` unless two primitive cells are compatible.

    Three checks are made in order, and the first failure is reported:
    equal number of atoms, equal multisets of atomic numbers, and cell
    volumes that agree under ``np.isclose`` default tolerances.

    TODO: tol

    Parameters
    ----------
    ref : ase.Atoms
        reference primitive cell
    prim : ase.Atoms
        target primitive cell

    Raises
    ------
    ValueError
        if any of the checks fails
    """
    n_ref, n_prim = len(ref), len(prim)
    if n_ref != n_prim:
        raise ValueError(
            'Number of atoms in reference primitive cell {} does not match '
            'target primitive {}'.format(n_ref, n_prim))

    same_species = sorted(ref.numbers) == sorted(prim.numbers)
    if not same_species:
        message = ('Atomic numbers do not match\nReference: {}\nTarget:'
                   ' {}\n').format(ref.numbers, prim.numbers)
        raise ValueError(message)

    ref_volume, prim_volume = ref.get_volume(), prim.get_volume()
    if not np.isclose(ref_volume, prim_volume):
        raise ValueError(
            'Volume for reference primitive cell {} does not match target '
            'primitive cell {}\n'.format(ref_volume, prim_volume))

195 

def get_primitive_cell(atoms, to_primitive=True, no_idealize=True,
                       symprec=1e-5):
    """ Gets primitive cell from spglib.

    Parameters
    ----------
    atoms : ase.Atoms
        atomic structure; must be periodic in all directions
    to_primitive : bool
        passed to spglib
    no_idealize : bool
        passed to spglib
    symprec : float
        symmetry tolerance passed to spglib

    Returns
    -------
    Atoms
        cell returned by ``spglib.standardize_cell``, built with periodic
        boundary conditions in all directions

    Raises
    ------
    ValueError
        if ``atoms`` is not periodic in all directions, or if spglib returns
        no cell
    """
    if not all(atoms.pbc):
        raise ValueError('atoms must have pbc.')
    atoms_as_tuple = ase_atoms_to_spglib_tuple(atoms)
    # Forward the caller's flags; they used to be ignored and hard-coded to
    # True (which remain their defaults, so default behavior is unchanged).
    spg_primitive_cell = spg.standardize_cell(atoms_as_tuple,
                                              to_primitive=to_primitive,
                                              no_idealize=no_idealize,
                                              symprec=symprec)
    # spglib may signal failure by returning None; report that explicitly
    # instead of failing later with an opaque "not subscriptable" TypeError.
    if spg_primitive_cell is None:
        raise ValueError('spglib failed to standardize the cell '
                         '(symprec={}).'.format(symprec))
    primitive_cell = atoms_module.Atoms(cell=spg_primitive_cell[0],
                                        scaled_positions=spg_primitive_cell[1],
                                        numbers=spg_primitive_cell[2],
                                        pbc=True)
    return primitive_cell

220 

221 

def _debug_log_atoms(atoms):
    """Emit the cell, scaled positions, Cartesian positions and atomic
    numbers of ``atoms`` as four separate debug-level log messages."""
    # Each value is fetched only right before its own message is logged.
    entries = (
        ('cell:\n{}', lambda: atoms.cell),
        ('spos:\n{}', atoms.get_scaled_positions),
        ('pos:\n{}', lambda: atoms.positions),
        ('numbers:\n{}', lambda: atoms.numbers),
    )
    for template, fetch in entries:
        logger.debug(template.format(fetch()))

227 

228 

def rotate_atoms(atoms, rotation):
    """Return a rotated copy of ``atoms``.

    The cell vectors and Cartesian positions are treated as rows, and each
    row ``v`` is replaced by ``rotation @ v``. The new object is built only
    from the rotated cell and positions plus the original atomic numbers
    and pbc flags. Nothing else from ``atoms`` is carried over, and the
    input is not modified.

    Parameters
    ----------
    atoms : ase.Atoms
        atomic structure
    rotation : numpy.ndarray
        rotation matrix (`3x3` array)

    Returns
    -------
    Atoms
        rotated copy of the structure
    """
    def _rotate_rows(rows):
        # rotate the rows as column vectors, then transpose back
        return np.dot(rotation, np.transpose(rows)).T

    return atoms_module.Atoms(cell=_rotate_rows(atoms.cell),
                              positions=_rotate_rows(atoms.positions),
                              numbers=atoms.numbers,
                              pbc=atoms.pbc)

243 

244 

def _find_translation(reference, target):
    """Return a translation mapping ``reference`` onto ``target``.

    The two structures must describe the same structure when infinitely
    repeated but differ by a translation. The reference atoms are placed in
    the target cell and wrapped. Each target atom with the same chemical
    symbol as the first wrapped atom is tried as its image, and the first
    resulting shift that makes :func:`are_nonpaired_configurations_equal`
    succeed is returned.

    Parameters
    ----------
    reference : ase.Atoms
    target : ase.Atoms

    Returns
    -------
    numpy.ndarray or None
        translation vector or `None` if structures are incompatible
    """
    # Express the reference atoms in the target cell and fold them into it.
    wrapped = atoms_module.Atoms(cell=target.cell,
                                 positions=reference.positions,
                                 numbers=reference.numbers,
                                 pbc=True)
    wrapped.wrap()

    anchor = wrapped[0]
    same_species = (atom for atom in target if atom.symbol == anchor.symbol)
    for site in same_species:
        shift = site.position - anchor.position
        shifted = wrapped.copy()
        shifted.positions += shift
        if are_nonpaired_configurations_equal(shifted, target):
            return shift
    return None

278 

279 

def are_nonpaired_configurations_equal(atoms1, atoms2):
    """Check whether two configurations are identical regardless of atom
    order.

    The configurations are considered equal when all of these hold:

    - their cells agree within an absolute tolerance of ``1e-4``;
    - they contain the same number of atoms with the same multiset of
      atomic numbers;
    - their pbc flags agree;
    - every atom of ``atoms1`` has an atom of ``atoms2`` within a
      minimum-image distance of ``1e-4``, and the first such atom has the
      same atomic number.

    Distances are measured in the average of the two cells, treated as
    periodic in all directions.

    Unlike the ``__eq__`` operator of :class:`ase.Atoms` the order of the
    atoms does not matter.

    Parameters
    ----------
    atoms1 : ase.Atoms
    atoms2 : ase.Atoms

    Returns
    -------
    bool
        True if atoms are equal, False otherwise

    TODO: tol
    """
    n_atoms = len(atoms1)
    compatible = (np.allclose(atoms1.cell, atoms2.cell, atol=1e-4)
                  and n_atoms == len(atoms2)
                  and sorted(atoms1.numbers) == sorted(atoms2.numbers)
                  and all(atoms1.pbc == atoms2.pbc))
    if not compatible:
        return False

    # Stack both configurations into one periodic system so that
    # minimum-image distances between the two sets can be queried:
    # indices [0, n_atoms) come from atoms1, the rest from atoms2.
    combined = atoms_module.Atoms(
        cell=(atoms1.cell + atoms2.cell) / 2,
        positions=np.concatenate([atoms1.positions, atoms2.positions]),
        numbers=np.concatenate([atoms1.numbers, atoms2.numbers]),
        pbc=True)
    numbers = combined.numbers
    second_set = range(n_atoms, len(combined))

    for i in range(n_atoms):
        # first atom of atoms2 sitting on top of atom i, if any  # TODO: tol
        partner = next(
            (j for j in second_set
             if combined.get_distance(i, j, mic=True) < 1e-4),
            None)
        if partner is None or numbers[i] != numbers[partner]:
            return False
    return True