Coverage for hiphive/cluster_space.py: 98%

236 statements  

« prev     ^ index     » next       coverage.py v7.10.1, created at 2025-08-01 17:04 +0000

"""
Defines the :class:`ClusterSpace` object, which is central to hiPhive.
"""

import tarfile
from collections import OrderedDict
from copy import deepcopy
from typing import BinaryIO, TextIO, Union

import numpy as np
from ase.data import chemical_symbols

from .core.cluster_space_builder import build_cluster_space
from .core.atoms import Atoms
from .core.orbits import Orbit
from .core.utilities import BiMap
from .cluster_space_data import ClusterSpaceData
from .input_output.logging_tools import logger
from .input_output.read_write_files import (
    add_items_to_tarfile_pickle,
    add_items_to_tarfile_custom,
    add_list_to_tarfile_custom,
    read_items_pickle,
    read_list_custom,
)
from .cutoffs import Cutoffs, CutoffMaximumBody, BaseClusterFilter

from .config import Config

# Module-level logger: a child of the package logger named 'ClusterSpace'.
logger = logger.getChild('ClusterSpace')

29 

30 

class ClusterSpace:
    """Primitive object for handling clusters and force constants of a structure.

    Parameters
    ----------
    prototype_structure : ase.Atoms
        Prototype structure; spglib will be used to find a suitable cell based
        on this structure unless the cell is already a primitive cell. The
        structure must be periodic along all three directions.
    cutoffs : list, tuple or Cutoffs
        Cutoff radii for different orders starting with second order, or a
        ready-made :class:`Cutoffs` object.
    config : Config object
        A configuration object that holds information on how the cluster space
        should be built, e.g., values for tolerances and specifications
        regarding the handling of acoustic sum rules; if ``config`` is
        not given then the keyword arguments that follow below can be
        used for configuration.
    cluster_filter : ClusterFilter
        Accepts a subclass of ``BaseClusterFilter`` to further control which
        orbits to include; if not given a plain ``BaseClusterFilter()`` is
        used.
    acoustic_sum_rules : bool
        If `True` the acoustic sum rules will be enforced by constraining the
        parameters.
    symprec : float
        Numerical precision that will be used for analyzing the symmetry (this
        parameter will be forwarded to `spglib <https://phonopy.github.io/spglib/>`_).
    length_scale : float
        This will be used as a normalization constant for the eigentensors.

    Raises
    ------
    ValueError
        If ``prototype_structure`` is not periodic in all directions, or if
        both ``config`` and configuration keyword arguments are given.
    TypeError
        If ``cutoffs`` or ``config`` is of an unsupported type.

    Examples
    --------

    To instantiate a :class:`ClusterSpace` object one has to specify a
    prototype structure and cutoff radii for each cluster order that
    should be included. For example the following snippet will set up
    a :class:`ClusterSpace` object for a body-centered-cubic (BCC)
    structure including second order terms up to a distance of 5 A and
    third order terms up to a distance of 4 A.

    >>> from ase.build import bulk
    >>> from hiphive import ClusterSpace
    >>> prim = bulk('W')
    >>> cs = ClusterSpace(prim, [5.0, 4.0])

    """
    # TODO: This class probably needs some more documentation

    def __init__(self, prototype_structure, cutoffs, config=None,
                 cluster_filter=None, **kwargs):

        if not all(prototype_structure.pbc):
            raise ValueError('prototype_structure must have pbc.')

        if isinstance(cutoffs, Cutoffs):
            self._cutoffs = cutoffs
        elif isinstance(cutoffs, (list, tuple)):
            # Radii start at second order, so n radii cover orders 2..n+1;
            # the second argument is therefore the highest order requested.
            # A fresh list is passed so the caller's sequence is never shared.
            self._cutoffs = CutoffMaximumBody(list(cutoffs), len(cutoffs) + 1)
        else:
            raise TypeError('cutoffs must be a list or a Cutoffs object')

        if config is None:
            config = Config(**kwargs)
        else:
            if not isinstance(config, Config):
                raise TypeError('config kw must be of type {}'.format(Config))
            if kwargs:
                raise ValueError('use either Config or kwargs, not both')
        self._config = config

        if cluster_filter is None:
            self._cluster_filter = BaseClusterFilter()
        else:
            self._cluster_filter = cluster_filter

        # Placeholders; presumably filled in by build_cluster_space() below
        # (which also appears to set _cvs and _dropped_orbits, read elsewhere
        # in this class but never assigned here).
        self._atom_list = None
        self._cluster_list = None
        self._symmetry_dataset = None
        self._permutations = None
        self._prim = None
        self._orbits = None

        self._constraint_vectors = None
        # TODO: How to handle the constraint matrices? Should they even be
        # stored?
        self._constraint_matrices = None
        # Is this the best way or should the prim be instantiated separately?

        build_cluster_space(self, prototype_structure)
        self.summary = ClusterSpaceData(self)

    @property
    def n_dofs(self) -> int:
        """Number of free parameters in the model.

        If the sum rules are not enforced the number of DOFs is the same as
        the total number of eigentensors in all orbits.
        """
        return self._get_n_dofs()

    @property
    def cutoffs(self) -> 'Cutoffs':
        """ Cutoffs used for constructing the cluster space (a deep copy). """
        return deepcopy(self._cutoffs)

    @property
    def symprec(self) -> float:
        """ Symprec value used when constructing the cluster space. """
        return self._config['symprec']

    @property
    def acoustic_sum_rules(self) -> bool:
        """ True if acoustic sum rules are enforced. """
        return self._config['acoustic_sum_rules']

    @property
    def length_scale(self) -> float:
        """ Normalization constant of the force constants. """
        return self._config['length_scale']

    @property
    def primitive_structure(self) -> 'Atoms':
        """ Structure of the lattice (a copy). """
        return self._prim.copy()

    @property
    def spacegroup(self) -> str:
        """ Space group of the lattice structure obtained from spglib,
        formatted as ``'<international symbol> (<number>)'``. """
        dataset = self._symmetry_dataset
        return f'{dataset.international} ({dataset.number})'

    @property
    def wyckoff_sites(self) -> list:
        """ Wyckoff sites in the primitive cell, i.e. a copy of spglib's
        ``equivalent_atoms`` array. """
        return self._symmetry_dataset.equivalent_atoms.copy()

    @property
    def rotation_matrices(self) -> np.ndarray:
        """ Symmetry elements (`3x3` matrices) representing rotations. """
        return self._symmetry_dataset.rotations.copy()

    @property
    def translation_vectors(self) -> np.ndarray:
        """ Symmetry elements representing translations, wrapped into the
        interval [0, 1). """
        # TODO: bug incoming! (original author's note; the concern is not
        # documented here)
        # ``% 1`` already returns a new array, so no extra copy is needed.
        return self._symmetry_dataset.translations % 1

    @property
    def permutations(self) -> list[np.ndarray]:
        """ Lookup for permutation references (a deep copy). """
        return deepcopy(self._permutations)

    @property
    def atom_list(self) -> 'BiMap':
        """ Atoms inside the cutoff relative to the center cell. """
        return self._atom_list

    @property
    def cluster_list(self) -> 'BiMap':
        """ Clusters possible within the cutoff. """
        return self._cluster_list

    @property
    def orbits(self) -> 'list[Orbit]':  # TODO: add __getitem__ method
        """ Orbits associated with the lattice structure. """
        return self._orbits

    @property
    def orbit_data(self) -> list[dict]:
        """ Detailed information for each orbit, e.g., cluster radius and atom types.

        Returns
        -------
        One dict per orbit, in orbit order, with the keys ``index``, ``order``,
        ``radius``, ``maximum_distance``, ``n_clusters``, ``eigentensors``,
        ``n_parameters``, ``prototype_cluster``, ``prototype_atom_types``,
        ``prototype_wyckoff_sites``, ``geometrical_order`` and
        ``parameter_indices``.
        """
        # Hoisted out of the loop: the public properties return copies, so
        # reading them per atom would copy the whole structure/array each time.
        numbers = self._prim.numbers
        equivalent_atoms = self._symmetry_dataset.equivalent_atoms

        data = []
        first_param = 0
        for orbit_index, orbit in enumerate(self.orbits):
            n_params = len(orbit.eigentensors)
            prototype_cluster = self.cluster_list[orbit.prototype_index]
            sites = [self.atom_list[atom_index].site
                     for atom_index in prototype_cluster]
            data.append({
                'index': orbit_index,
                'order': orbit.order,
                'radius': orbit.radius,
                'maximum_distance': orbit.maximum_distance,
                'n_clusters': len(orbit.orientation_families),
                'eigentensors': orbit.eigentensors,
                'n_parameters': n_params,
                'prototype_cluster': prototype_cluster,
                'prototype_atom_types': tuple(numbers[s] for s in sites),
                'prototype_wyckoff_sites': tuple(equivalent_atoms[s]
                                                 for s in sites),
                # number of distinct atoms in the prototype cluster
                'geometrical_order': len(set(prototype_cluster)),
                # eigentensor parameters are numbered consecutively over orbits
                'parameter_indices': np.arange(first_param,
                                               first_param + n_params),
            })
            first_param += n_params

        return data

    def get_parameter_indices(self, order: int) -> list[int]:
        """
        Returns a list of the parameter indices associated with the requested
        order.

        Parameters
        ----------
        order
            Order for which to return the parameter indices.

        Returns
        -------
        List of parameter indices associated with the requested order.

        Raises
        ------
        ValueError
            If the order is not included in the cluster space.
        """
        order = int(order)
        orders = self.cutoffs.orders  # deep-copies the cutoffs only once
        if order not in orders:
            raise ValueError('Order must be in {}'.format(orders))

        # Eigentensor rows [min_param, max_param) belong to the requested
        # order. This relies on self.orbits being sorted by order.
        min_param = 0
        max_param = 0
        for orbit in self.orbits:
            if orbit.order < order:
                min_param += len(orbit.eigentensors)
                max_param = min_param
            elif orbit.order == order:
                max_param += len(orbit.eigentensors)
            else:
                break

        # A free parameter (column of the constraint vectors) belongs to this
        # order if it feeds any eigentensor row of this order.
        rows, cols = self._cvs.nonzero()
        parameters = {c for r, c in zip(rows, cols) if min_param <= r < max_param}

        # Internal invariant: a parameter must never mix different orders.
        for r, c in zip(rows, cols):
            if c in parameters:
                assert min_param <= r < max_param, 'The internals are broken!'

        return sorted(parameters)

    def get_n_dofs_by_order(self, order: int) -> int:
        """ Returns number of degrees of freedom for the given order.

        Parameters
        ----------
        order
            Order for which to return the number of DOFs.

        Returns
        -------
        Number of degrees of freedom.
        """
        return len(self.get_parameter_indices(order=order))

    def _get_n_dofs(self):
        """ Returns the number of degrees of freedom, i.e. the number of
        columns of the constraint vectors. """
        return self._cvs.shape[1]

    def _map_parameters(self, parameters):
        """ Maps irreducible parameters to the real parameters associated with
        the eigentensors.

        Raises
        ------
        ValueError
            If ``len(parameters)`` differs from :attr:`n_dofs`.
        """
        if len(parameters) != self.n_dofs:
            raise ValueError('Invalid number of parameters, please provide {} '
                             'parameters'.format(self.n_dofs))
        return self._cvs.dot(parameters)

    def print_tables(self):
        """ Prints information concerning the underlying cluster space to stdout,
        including, e.g., the number of clusters, orbits, and parameters by order
        and number of bodies. """
        self.summary.print_tables()

    def print_orbits(self):
        """ Prints a list of all orbits. """
        orbits = self.orbit_data

        def str_orbit(index, orbit):
            # index < 0 renders the header row (field names); otherwise the
            # data row for ``orbit``. Each column is as wide as the larger of
            # its name and its formatted value.
            elements = ' '.join(chemical_symbols[n] for n in
                                orbit['prototype_atom_types'])
            fields = OrderedDict([
                ('index', '{:^3}'.format(index)),
                ('order', '{:^3}'.format(orbit['order'])),
                ('elements', '{:^18}'.format(elements)),
                ('radius', '{:^8.4f}'.format(orbit['radius'])),
                ('prototype', '{:^18}'.format(str(orbit['prototype_cluster']))),
                ('clusters', '{:^4}'.format(orbit['n_clusters'])),
                ('parameters', '{:^3}'.format(len(orbit['eigentensors']))),
            ])

            s = []
            for name, value in fields.items():
                n = max(len(name), len(value))
                if index < 0:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=value, n=n)]
            return ' | '.join(s)

        # table header
        width = max(len(str_orbit(-1, orbits[-1])), len(str_orbit(0, orbits[-1])))
        print(' List of Orbits '.center(width, '='))
        print(str_orbit(-1, orbits[0]))
        print(''.center(width, '-'))

        # table body
        for i, orbit in enumerate(orbits):
            print(str_orbit(i, orbit))
        print(''.center(width, '='))

    def __str__(self):

        def str_order(order, header: bool = False):
            # Renders one row of the per-order table (or its header row).
            formats = {'order': '{:2}',
                       'n_orbits': '{:5}',
                       'n_clusters': '{:5}'}
            s = []
            for name, value in order.items():
                str_repr = formats[name].format(value)
                n = max(len(name), len(str_repr))
                if header:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=str_repr, n=n)]
            return ' | '.join(s)

        # collect data
        orbits = self.orbit_data
        orders = self.cutoffs.orders

        order_data = {o: dict(order=o, n_orbits=0, n_clusters=0) for o in orders}
        for orbit in orbits:
            o = orbit['order']
            order_data[o]['n_orbits'] += 1
            order_data[o]['n_clusters'] += orbit['n_clusters']

        # prototype with max order to find column width
        max_order = max(orders)
        prototype = order_data[max_order]
        n = max(len(str_order(prototype)), 54)

        # basic information
        s = []
        s.append(' Cluster Space '.center(n, '='))
        data = [('Spacegroup', self.spacegroup),
                ('symprec', self.symprec),
                ('Sum rules', self.acoustic_sum_rules),
                ('Length scale', self.length_scale),
                ('Cutoffs', self.cutoffs),
                ('Cell', self.primitive_structure.cell),
                ('Basis', self.primitive_structure.basis),
                ('Numbers', self.primitive_structure.numbers),
                ('Total number of orbits', len(orbits)),
                ('Total number of clusters',
                 sum(order_data[order]['n_clusters'] for order in orders)),
                ('Total number of parameters', self._get_n_dofs()
                 )]
        for field, value in data:
            if str(value).count('\n') > 1:
                s.append('{:26} :\n{}'.format(field, value))
            else:
                s.append('{:26} : {}'.format(field, value))

        # table header
        s.append(''.center(n, '-'))
        s.append(str_order(prototype, header=True))
        s.append(''.center(n, '-'))
        for order in orders:
            s.append(str_order(order_data[order]).rstrip())
        s.append(''.center(n, '='))
        return '\n'.join(s)

    def __repr__(self):
        s = 'ClusterSpace({!r}, {!r}, {!r})'
        return s.format(self.primitive_structure, self.cutoffs, self._config)

    def copy(self):
        """ Returns a deep copy of this cluster space. """
        return deepcopy(self)

    def write(self, fileobj: Union[str, BinaryIO, TextIO]):
        """ Writes cluster space to file.

        The instance is saved into a custom format based on tar-files. The
        resulting file will be a valid tar file and can be browsed by a tar
        reader. The included objects are themselves either pickles, npz or
        other tars.

        Parameters
        ----------
        fileobj
            If the input is a string a tar archive will be created in the
            current directory. Otherwise the input must be a valid file
            like object.
        """
        if isinstance(fileobj, str):
            tar_file = tarfile.open(name=fileobj, mode='w')
        else:
            tar_file = tarfile.open(fileobj=fileobj, mode='w')

        # The context manager closes the archive even if one of the writers
        # below raises, so the file handle is never leaked.
        with tar_file:
            # Attributes in pickle format.
            # NOTE: _cluster_filter is not stored, so read() cannot restore it.
            pickle_attributes = ['_config',
                                 '_symmetry_dataset', '_permutations',
                                 '_atom_list', '_cluster_list']
            items_pickle = {attribute: getattr(self, attribute)
                            for attribute in pickle_attributes}
            add_items_to_tarfile_pickle(tar_file, items_pickle, 'attributes')

            # Constraint vectors, also stored via the pickle helper
            items = dict(cvs=self._cvs)
            add_items_to_tarfile_pickle(tar_file, items, 'constraint_vectors')

            # Cutoffs and prim with their builtin write/read functions
            items_custom = {'_cutoffs': self._cutoffs, '_prim': self._prim}
            add_items_to_tarfile_custom(tar_file, items_custom)

            # Orbits
            add_list_to_tarfile_custom(tar_file, self._orbits, 'orbits')
            add_list_to_tarfile_custom(tar_file, self._dropped_orbits,
                                       'dropped_orbits')

    @staticmethod
    def read(f: Union[str, BinaryIO, TextIO]) -> 'ClusterSpace':
        """ Reads a cluster space from file.

        Counterpart of :meth:`write`. Call it as ``ClusterSpace.read(f)``.

        Parameters
        ----------
        f
            Name of input file (`str`) or stream to load from (file object).

        Returns
        -------
        The cluster space stored in ``f``.
        """
        # Bypass __init__, which would rebuild the cluster space from scratch.
        cs = ClusterSpace.__new__(ClusterSpace)

        # Load from file on disk or file-like
        if isinstance(f, str):
            tar_file = tarfile.open(mode='r', name=f)
        else:
            tar_file = tarfile.open(mode='r', fileobj=f)

        # NOTE(review): attributes and constraint vectors appear to be stored
        # as pickles (see read_items_pickle); only load files from trusted
        # sources.
        with tar_file:
            # Attributes
            attributes = read_items_pickle(tar_file, 'attributes')
            for name, value in attributes.items():
                setattr(cs, name, value)

            # Constraint vectors
            items = read_items_pickle(tar_file, 'constraint_vectors')
            cs._cvs = items['cvs']

            # Cutoffs and prim via custom save funcs
            cs._cutoffs = Cutoffs.read(tar_file.extractfile('_cutoffs'))
            cs._prim = Atoms.read(tar_file.extractfile('_prim'))

            # Orbits are stored in a separate archive
            cs._orbits = read_list_custom(tar_file, 'orbits', Orbit.read)
            cs._dropped_orbits = read_list_custom(
                tar_file, 'dropped_orbits', Orbit.read)

        # create summary object based on CS
        cs.summary = ClusterSpaceData(cs)

        return cs