Coverage for hiphive/structure_container.py: 95%

192 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-01 17:04 +0000

1""" 

2This module provides functionality for storing structures and their fit 

3matrices together with target forces and displacements 

4""" 

5 

6import tarfile 

7import numpy as np 

8from collections import OrderedDict 

9from typing import Union, IO 

10from ase import Atoms 

11from ase.calculators.singlepoint import SinglePointCalculator 

12 

13from .input_output.read_write_files import (add_items_to_tarfile_hdf5, 

14 add_items_to_tarfile_pickle, 

15 add_items_to_tarfile_custom, 

16 add_list_to_tarfile_custom, 

17 read_items_hdf5, 

18 read_items_pickle, 

19 read_list_custom) 

20 

21from .cluster_space import ClusterSpace 

22from .force_constant_model import ForceConstantModel 

23from .input_output.logging_tools import logger 

24logger = logger.getChild('structure_container') 

25 

26 

class StructureContainer:
    """
    This class serves as a container for structures as well as associated
    fit properties and fit matrices.

    Parameters
    -----------
    cs : ClusterSpace
        Cluster space that is the basis for the container.
    fit_structure_list : list[FitStructure]
        Structures to be added to the container.
    """

    def __init__(self, cs, fit_structure_list=None):
        """
        Attributes
        -----------
        _cs : ClusterSpace
            cluster space that is the basis for the container
        _structure_list : list(FitStructure)
            structures to add to container
        _previous_fcm : ForceConstantModel
            FCM object used for last fit matrix calculation; check will be
            carried out to decide if this FCM can be used for a new structure
            or not, which often enables a considerable speed-up

        Raises
        ------
        TypeError
            If an element of `fit_structure_list` is not a FitStructure.
        """
        self._cs = cs.copy()
        self._previous_fcm = None

        # Add the supplied fit structures (they already carry fit matrices)
        self._structure_list = []
        if fit_structure_list is not None:
            for fit_structure in fit_structure_list:
                if not isinstance(fit_structure, FitStructure):
                    raise TypeError('Can only add FitStructures')
                self._structure_list.append(fit_structure)

    def __len__(self):
        return len(self._structure_list)

    def __getitem__(self, ind):
        return self._structure_list[ind]

    @property
    def data_shape(self) -> Union[tuple[int, int], None]:
        """ Tuple of integers representing the shape of the fit data
        matrix, or `None` if the container holds no structures. """
        n_cols = self._cs.n_dofs
        # three force components (rows) per atom
        n_rows = sum(len(fs) * 3 for fs in self)
        if n_rows == 0:
            return None
        return n_rows, n_cols

    @property
    def cluster_space(self) -> 'ClusterSpace':
        """ Copy of the cluster space the structure
        container is based on"""
        return self._cs.copy()

    @staticmethod
    def read(fileobj: Union[str, IO], read_structures: bool = True):
        """Restore a :class:`StructureContainer` object from file.

        Parameters
        ----------
        fileobj
            Name of input file (`str`) or stream to load from (file object).
        read_structures
            If `True` the structures will be read; if `False` only the cluster
            space will be read.
        """
        if isinstance(fileobj, str):
            tar_file = tarfile.open(mode='r', name=fileobj)
        else:
            tar_file = tarfile.open(mode='r', fileobj=fileobj)

        # Closing the TarFile releases the handle opened for a path; a
        # caller-supplied stream is left open by tarfile itself.
        with tar_file:
            # Read clusterspace
            f = tar_file.extractfile('cluster_space')
            cs = ClusterSpace.read(f)

            # Read fitstructures
            fit_structure_list = None
            if read_structures:
                fit_structure_list = read_list_custom(
                    tar_file, 'fit_structure', FitStructure.read)

            # setup StructureContainer
            sc = StructureContainer(cs, fit_structure_list)

            # Read previous FCM if it exists
            if 'previous_fcm' in tar_file.getnames():
                f = tar_file.extractfile('previous_fcm')
                sc._previous_fcm = ForceConstantModel.read(f)

        return sc

    def write(self, f: Union[str, IO]):
        """Write a :class:`StructureContainer` instance to file.

        Parameters
        ----------
        f
            Name of output file (`str`) or stream to write to (file object).
        """
        if isinstance(f, str):
            tar_file = tarfile.open(mode='w', name=f)
        else:
            tar_file = tarfile.open(mode='w', fileobj=f)

        # The with-block guarantees the archive is closed even if a writer
        # helper raises.
        with tar_file:
            # save cs and previous_fcm (if it exists)
            custom_items = dict(cluster_space=self._cs)
            if self._previous_fcm is not None:
                custom_items['previous_fcm'] = self._previous_fcm
            add_items_to_tarfile_custom(tar_file, custom_items)

            # save fit structures
            add_list_to_tarfile_custom(tar_file, self._structure_list, 'fit_structure')

    def add_structure(self, atoms: 'Atoms', **meta_data):
        """Add a structure to the container.

        Note that custom information about the atoms object may not be
        stored inside, for example an ASE
        :class:`SinglePointCalculator` will not be kept.

        Parameters
        ----------
        atoms
            The structure to be added; the Atoms object must contain
            supplementary per-atom arrays with displacements and forces.
        meta_data
            Dict with meta_data about the atoms.

        Raises
        ------
        ValueError
            If displacements or forces are missing, or if an identical
            configuration is already present in the container.
        """
        atoms_copy = atoms.copy()

        # atoms object must contain displacements
        if 'displacements' not in atoms_copy.arrays:
            raise ValueError('Atoms must have displacements array')

        # atoms object must contain forces; fall back to a single-point
        # calculator attached to the original object
        if 'forces' not in atoms_copy.arrays:
            if isinstance(atoms.calc, SinglePointCalculator):
                atoms_copy.new_array('forces', atoms.get_forces())
            else:
                raise ValueError('Atoms must have forces')

        # check if an identical atoms object already exists in the container
        for i, structure in enumerate(self._structure_list):
            if are_configurations_equal(atoms_copy, structure.atoms):
                raise ValueError('Atoms is identical to structure {}'.format(i))

        logger.debug('Adding structure')
        M = self._compute_fit_matrix(atoms)
        structure = FitStructure(atoms_copy, M, **meta_data)
        self._structure_list.append(structure)

    def delete_all_structures(self):
        """ Remove all current structures in :class:`StructureContainer`. """
        self._structure_list = []

    def get_fit_data(
        self,
        structures: Union[list[int], tuple[int, ...], None] = None,
    ) -> Union[tuple[np.ndarray, np.ndarray], None]:
        """Return fit data for structures. The fit matrices and target forces
        for the structures are stacked into NumPy arrays.

        Parameters
        ----------
        structures
            List of integers corresponding to structure indices. Defaults to
            `None` and in that case returns all fit data available.

        Returns
        -------
        Stacked fit matrices, stacked target forces for the structures;
        `None` if no structures are selected.
        """
        if structures is None:
            structures = range(len(self))

        M_list, f_list = [], []
        for i in structures:
            M_list.append(self._structure_list[i].fit_matrix)
            f_list.append(self._structure_list[i].forces.flatten())

        if not M_list:
            return None
        return np.vstack(M_list), np.hstack(f_list)

    def __str__(self):
        if len(self._structure_list) > 0:
            return self._get_str_structure_list()
        else:
            return 'Empty StructureContainer'

    def __repr__(self):
        return 'StructureContainer({!r}, {!r})'.format(
            self._cs, self._structure_list)

    def _get_str_structure_list(self):
        """ Return formatted string of the structure list """
        def str_structure(index, structure):
            # A negative index renders the header row (field names)
            fields = OrderedDict([
                ('index', '{:^4}'.format(index)),
                ('num-atoms', '{:^5}'.format(len(structure))),
                ('avg-disp', '{:7.4f}'
                 .format(np.mean([np.linalg.norm(d) for d in
                                  structure.displacements]))),
                ('avg-force', '{:7.4f}'
                 .format(np.mean([np.linalg.norm(f) for f in
                                  structure.forces]))),
                ('max-force', '{:7.4f}'
                 .format(np.max([np.linalg.norm(f) for f in
                                 structure.forces])))])
            s = []
            for name, value in fields.items():
                n = max(len(name), len(value))
                if index < 0:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=value, n=n)]
            return ' | '.join(s)

        # table width
        dummy = self._structure_list[0]
        n = len(str_structure(-1, dummy))

        # table header
        s = []
        s.append(' Structure Container '.center(n, '='))
        s += ['{:22} : {}'.format('Total number of structures', len(self))]
        _, target_forces = self.get_fit_data()
        s += ['{:22} : {}'.format('Number of force components', len(target_forces))]
        s.append(''.center(n, '-'))
        s.append(str_structure(-1, dummy))
        s.append(''.center(n, '-'))

        # table body
        for i, structure in enumerate(self._structure_list):
            s.append(str_structure(i, structure))
        s.append(''.center(n, '='))
        return '\n'.join(s)

    def _compute_fit_matrix(self, atoms):
        """ Compute fit matrix for a single atoms object """
        logger.debug('Computing fit matrix')
        # Reuse the cached FCM when the new structure matches its atoms;
        # building a new ForceConstantModel is the expensive step.
        if atoms != getattr(self._previous_fcm, 'atoms', None):
            logger.debug('    Building new FCM object')
            self._previous_fcm = ForceConstantModel(atoms, self._cs)
        else:
            logger.debug('    Reusing old FCM object')
        return self._previous_fcm.get_fit_matrix(atoms.get_array('displacements'))

283 

284 

class FitStructure:
    """This class holds a structure with displacements and forces as well as
    the fit matrix.

    Parameters
    ----------
    atoms : Atoms
        Supercell structure.
    fit_matrix
        Fit matrix, `N, M` array with `N = 3 * len(atoms)`; may be `None`
        when the structure was read without its fit matrix.
    meta_data
        Any meta data that needs to be stored in the :class:`FitStructure`.
    """

    def __init__(self, atoms, fit_matrix, **meta_data):
        # one fit-matrix row per Cartesian force component
        if fit_matrix is not None and 3 * len(atoms) != fit_matrix.shape[0]:
            raise ValueError('fit matrix not compatible with atoms')
        self._atoms = atoms
        self._fit_matrix = fit_matrix
        self.meta_data = meta_data

    @property
    def fit_matrix(self) -> np.ndarray:
        """ The fit matrix (`None` if it was not read from file). """
        return self._fit_matrix

    @property
    def atoms(self) -> 'Atoms':
        """ Supercell structure (returned as a copy). """
        return self._atoms.copy()

    @property
    def forces(self) -> np.ndarray:
        """ Forces. """
        return self._atoms.get_array('forces')

    @property
    def displacements(self) -> np.ndarray:
        """ Atomic displacements. """
        return self._atoms.get_array('displacements')

    def __getattr__(self, key):
        """ Accesses meta_data if possible and returns value. """
        # Look meta_data up through __dict__: __getattr__ may run before
        # __init__ (e.g. copy.copy / unpickling probe for __setstate__), and
        # ``self.meta_data`` would then re-enter __getattr__ endlessly.
        meta_data = self.__dict__.get('meta_data', {})
        if key in meta_data:
            return meta_data[key]
        # raises the regular AttributeError for unknown names
        return super().__getattribute__(key)

    def __len__(self):
        return len(self._atoms)

    def __str__(self):
        # use the stored atoms directly instead of copying on every access
        atoms = self._atoms
        cell = atoms.cell
        array_fmt = '[ {:9.5f} {:9.5f} {:9.5f} ]'
        s = []
        s.append(' FitStructure '.center(65, '='))
        s.append('Formula: {}'.format(atoms.get_chemical_formula()))
        s.append(('Cell:' + '\n [{:9.5f} {:9.5f} {:9.5f}]'*3).format(
            *cell[0], *cell[1], *cell[2]))
        s.append('Atoms (positions, displacements, forces):')
        for atom, disp, force in zip(atoms, self.displacements, self.forces):
            row_str = '{:3d} {}'.format(atom.index, atom.symbol)
            row_str += array_fmt.format(*atom.position)
            row_str += array_fmt.format(*disp)
            row_str += array_fmt.format(*force)
            s.append(row_str)
        return '\n'.join(s)

    def __repr__(self):
        return 'FitStructure({!r}, ..., {!r})'.format(self.atoms, self.meta_data)

    def write(self, fileobj):
        """ Write the instance to file.

        Parameters
        ----------
        fileobj : str or file object
            name of output file (str) or stream to write to (file object)
        """
        if isinstance(fileobj, str):
            tar_file = tarfile.open(name=fileobj, mode='w')
        else:
            tar_file = tarfile.open(fileobj=fileobj, mode='w')

        # the with-block closes the archive even if a writer helper raises
        with tar_file:
            items_pickle = dict(atoms=self._atoms, meta_data=self.meta_data)
            items_hdf5 = dict(fit_matrix=self.fit_matrix)

            add_items_to_tarfile_pickle(tar_file, items_pickle, 'items.pickle')
            add_items_to_tarfile_hdf5(tar_file, items_hdf5, 'fit_matrix.hdf5')

    @staticmethod
    def read(fileobj: Union[str, IO], read_fit_matrix: bool = True):
        """ Read a :class:`FitStructure` instance from file.

        Parameters
        ----------
        fileobj
            Name of input file (`str`) or stream to read from (file object).
        read_fit_matrix
            Whether or not to read the fit matrix; if `False` the returned
            object has `fit_matrix` set to `None`.
        """
        if isinstance(fileobj, str):
            tar_file = tarfile.open(mode='r', name=fileobj)
        else:
            tar_file = tarfile.open(mode='r', fileobj=fileobj)

        with tar_file:
            # NOTE(review): 'items.pickle' is presumably unpickled by
            # read_items_pickle -- only read files from trusted sources.
            items = read_items_pickle(tar_file, 'items.pickle')
            fit_matrix = None
            if read_fit_matrix:
                fit_matrix = read_items_hdf5(tar_file, 'fit_matrix.hdf5')['fit_matrix']

        return FitStructure(items['atoms'], fit_matrix, **items['meta_data'])

395 

396 

def are_configurations_equal(atoms1: 'Atoms', atoms2: 'Atoms', tol: float = 1e-10):
    """ Compare if two configurations are equal within some tolerance. This
    includes checking all available arrays in the two atoms objects.

    Parameters
    ----------
    atoms1
        First configuration.
    atoms2
        Second configuration.
    tol
        Numerical (absolute) tolerance imposed during comparison.

    Returns
    -------
    `True` if atoms are equal, `False` otherwise.
    """

    # pbc
    if not np.array_equal(atoms1.pbc, atoms2.pbc):
        return False

    # cell
    if not np.allclose(atoms1.cell, atoms2.cell, atol=tol, rtol=0.0):
        return False

    # arrays: both objects must carry exactly the same set of per-atom arrays
    arrays1, arrays2 = atoms1.arrays, atoms2.arrays
    if arrays1.keys() != arrays2.keys():
        return False
    for key, array1 in arrays1.items():
        array2 = arrays2[key]
        # np.allclose broadcasts, so arrays of different shape (e.g. a
        # different number of atoms) could raise or falsely compare equal.
        if np.shape(array1) != np.shape(array2):
            return False
        if not np.allclose(array1, array2, atol=tol, rtol=0.0):
            return False

    # passed all test, atoms must be equal
    return True