Coverage for hiphive/cluster_space.py: 98%

234 statements  

« prev     ^ index     » next       coverage.py v7.6.8, created at 2024-11-28 11:20 +0000

1""" 

2Contains the ClusterSpace object central to hiPhive 

3""" 

4 

5import numpy as np 

6import tarfile 

7 

8from ase.data import chemical_symbols 

9from collections import OrderedDict 

10from copy import deepcopy 

11 

12from .core.cluster_space_builder import build_cluster_space 

13from .core.atoms import Atoms 

14from .core.orbits import Orbit 

15from .cluster_space_data import ClusterSpaceData 

16from .input_output.logging_tools import logger 

17from .input_output.read_write_files import (add_items_to_tarfile_pickle, 

18 add_items_to_tarfile_custom, 

19 add_list_to_tarfile_custom, 

20 read_items_pickle, 

21 read_list_custom) 

22from .cutoffs import Cutoffs, CutoffMaximumBody, BaseClusterFilter 

23 

24from .config import Config 

25logger = logger.getChild('ClusterSpace') 

26 

27 

class ClusterSpace:
    """Primitive object for handling clusters and force constants of a structure.

    Parameters
    ----------
    prototype_structure : ase.Atoms
        prototype structure; spglib will be used to find a suitable cell based
        on this structure unless the cell is already a primitive cell.
    cutoffs : list or Cutoffs
        cutoff radii for different orders starting with second order
    config : Config object
        a configuration object that holds information on how the cluster space
        should be built, e.g., values for tolerances and specifications
        regarding the handling of acoustic sum rules; if ``config`` is
        not given then the keyword arguments that follow below can be
        used for configuration.
    cluster_filter : ClusterFilter
        accepts a subclass of :class:`BaseClusterFilter` to further
        control which orbits to include.
    acoustic_sum_rules : bool
        If True the acoustic sum rules will be enforced by constraining the
        parameters.
    symprec : float
        numerical precision that will be used for analyzing the symmetry (this
        parameter will be forwarded to
        `spglib <https://atztogo.github.io/spglib/>`_)
    length_scale : float
        This will be used as a normalization constant for the eigentensors

    Raises
    ------
    ValueError
        if ``prototype_structure`` is not periodic in all directions, or if
        both ``config`` and configuration keyword arguments are given
    TypeError
        if ``cutoffs`` is neither a list nor a ``Cutoffs`` object, or if
        ``config`` is not a ``Config`` object

    Examples
    --------

    To instantiate a :class:`ClusterSpace` object one has to specify a
    prototype structure and cutoff radii for each cluster order that
    should be included. For example the following snippet will set up
    a :class:`ClusterSpace` object for a body-centered-cubic (BCC)
    structure including second order terms up to a distance of 5 A and
    third order terms up to a distance of 4 A.

    >>> from ase.build import bulk
    >>> from hiphive import ClusterSpace
    >>> prim = bulk('W')
    >>> cs = ClusterSpace(prim, [5.0, 4.0])

    """
    # TODO: This class probably needs some more documentation

    def __init__(self, prototype_structure, cutoffs, config=None,
                 cluster_filter=None, **kwargs):

        if not all(prototype_structure.pbc):
            raise ValueError('prototype_structure must have pbc.')

        # A plain list of radii is wrapped into a cutoff object; the maximum
        # body is set to one more than the number of radii provided.
        if isinstance(cutoffs, Cutoffs):
            self._cutoffs = cutoffs
        elif isinstance(cutoffs, list):
            self._cutoffs = CutoffMaximumBody(cutoffs, len(cutoffs) + 1)
        else:
            raise TypeError('cutoffs must be a list or a Cutoffs object')

        # Configuration comes either from an explicit Config object or from
        # the keyword arguments, never from both.
        if config is None:
            config = Config(**kwargs)
        else:
            if not isinstance(config, Config):
                raise TypeError('config kw must be of type {}'.format(Config))
            if len(kwargs) > 0:
                raise ValueError('use either Config or kwargs, not both')
        self._config = config

        if cluster_filter is None:
            self._cluster_filter = BaseClusterFilter()
        else:
            self._cluster_filter = cluster_filter

        # Placeholders; build_cluster_space receives this instance below.
        self._atom_list = None
        self._cluster_list = None
        self._symmetry_dataset = None
        self._permutations = None
        self._prim = None
        self._orbits = None

        self._constraint_vectors = None
        # TODO: How to handle the constraint matrices? Should they even be
        # stored?
        self._constraint_matrices = None
        # Is this the best way or should the prim be instantiated separately?

        build_cluster_space(self, prototype_structure)
        self.summary = ClusterSpaceData(self)

    # TODO: Should everything here be properties? deepcopy/ref etc.?
    @property
    def n_dofs(self):
        """int : number of free parameters in the model

        If the sum rules are not enforced the number of DOFs is the same as
        the total number of eigentensors in all orbits.
        """
        return self._get_n_dofs()

    @property
    def cutoffs(self):
        """ Cutoffs : cutoffs used for constructing the cluster space
        (returned as a deep copy) """
        return deepcopy(self._cutoffs)

    @property
    def symprec(self):
        """ float : symprec value used when constructing the cluster space """
        return self._config['symprec']

    @property
    def acoustic_sum_rules(self):
        """ bool : True if acoustic sum rules are enforced """
        return self._config['acoustic_sum_rules']

    @property
    def length_scale(self):
        """ float : normalization constant of the force constants """
        return self._config['length_scale']

    @property
    def primitive_structure(self):
        """ ase.Atoms : structure of the lattice (returned as a copy) """
        return self._prim.copy()

    @property
    def spacegroup(self):
        """ str : space group of the lattice structure obtained from spglib """
        return f'{self._symmetry_dataset.international} ({self._symmetry_dataset.number})'

    @property
    def wyckoff_sites(self):
        """ list : wyckoff sites in the primitive cell """
        return self._symmetry_dataset.equivalent_atoms

    @property
    def rotation_matrices(self):
        """ list(numpy.ndarray) : symmetry elements (`3x3` matrices)
        representing rotations """
        return self._symmetry_dataset.rotations.copy()

    @property
    def translation_vectors(self):
        """ list(numpy.ndarray) : symmetry elements representing
        translations """
        # TODO: bug incoming!
        return (self._symmetry_dataset.translations % 1).copy()

    @property
    def permutations(self):
        """ list(numpy.ndarray) : lookup for permutation references
        (returned as a deep copy) """
        return deepcopy(self._permutations)

    @property
    def atom_list(self):
        """ BiMap : atoms inside the cutoff relative to the center cell
        (the internal object is returned, not a copy)
        """
        return self._atom_list

    @property
    def cluster_list(self):
        """ BiMap : clusters possible within the cutoff
        (the internal object is returned, not a copy) """
        return self._cluster_list

    @property
    def orbits(self):  # TODO: add __getitem__ method
        """ list(Orbit) : orbits associated with the lattice structure
        (the internal list is returned, not a copy). """
        return self._orbits

    @property
    def orbit_data(self):
        """ list(dict) : detailed information for each orbit, e.g., cluster
        radius and atom types.

        Each dictionary holds the keys ``index``, ``order``, ``radius``,
        ``maximum_distance``, ``n_clusters``, ``eigentensors``,
        ``n_parameters``, ``prototype_cluster``, ``prototype_atom_types``,
        ``prototype_wyckoff_sites``, ``geometrical_order`` and
        ``parameter_indices``.
        """
        # primitive_structure returns a fresh copy on every access, so fetch
        # the lookup tables once instead of once per atom.
        atomic_numbers = self.primitive_structure.numbers
        wyckoff_lookup = self.wyckoff_sites

        data = []
        p = 0  # running offset into the list of eigentensor parameters
        for orbit_index, orbit in enumerate(self.orbits):
            d = {}
            d['index'] = orbit_index
            d['order'] = orbit.order
            d['radius'] = orbit.radius
            d['maximum_distance'] = orbit.maximum_distance
            d['n_clusters'] = len(orbit.orientation_families)
            d['eigentensors'] = orbit.eigentensors
            n_parameters = len(d['eigentensors'])
            d['n_parameters'] = n_parameters

            prototype_cluster = self.cluster_list[orbit.prototype_index]
            sites = [self.atom_list[atom_index].site
                     for atom_index in prototype_cluster]
            d['prototype_cluster'] = prototype_cluster
            d['prototype_atom_types'] = tuple(atomic_numbers[site]
                                              for site in sites)
            d['prototype_wyckoff_sites'] = tuple(wyckoff_lookup[site]
                                                 for site in sites)

            # number of distinct atoms in the prototype cluster
            d['geometrical_order'] = len(set(prototype_cluster))
            d['parameter_indices'] = np.arange(p, p + n_parameters)

            p += n_parameters
            data.append(d)

        return data

    def get_parameter_indices(self, order):
        """
        Returns a list of the parameter indices associated with the requested
        order.

        Parameters
        ----------
        order : int
            order for which to return the parameter indices

        Returns
        -------
        list(int)
            list of parameter indices associated with the requested order

        Raises
        ------
        ValueError
            if the order is not included in the cluster space
        """
        order = int(order)
        valid_orders = self.cutoffs.orders  # one deep copy instead of two
        if order not in valid_orders:
            raise ValueError('Order must be in {}'.format(valid_orders))

        # Find the half-open row range [min_param, max_param) covered by the
        # eigentensors of the requested order; the loop stops at the first
        # orbit of higher order.
        min_param = 0
        max_param = 0
        for orbit in self.orbits:
            n_eigentensors = len(orbit.eigentensors)
            if orbit.order < order:
                min_param += n_eigentensors
                max_param = min_param
            elif orbit.order == order:
                max_param += n_eigentensors
            else:
                break

        # Columns of the constraint vectors with a non-zero entry in that row
        # range are the free parameters of this order.
        rows, cols = self._cvs.nonzero()
        parameters = set()
        for r, c in zip(rows, cols):
            if min_param <= r < max_param:
                parameters.add(c)
        # Sanity check: a free parameter must not couple to another order.
        for r, c in zip(rows, cols):
            if c in parameters:
                assert min_param <= r < max_param, 'The internals are broken!'

        return sorted(parameters)

    def get_n_dofs_by_order(self, order):
        """ Returns number of degrees of freedom for the given order.

        Parameters
        ----------
        order : int
            order for which to return the number of dofs

        Returns
        -------
        int
            number of degrees of freedom
        """
        return len(self.get_parameter_indices(order=order))

    def _get_n_dofs(self):
        """ Returns the number of degrees of freedom. """
        return self._cvs.shape[1]

    def _map_parameters(self, parameters):
        """ Maps irreducible parameters to the real parameters associated with
        the eigentensors.

        Raises
        ------
        ValueError
            if the number of parameters differs from :attr:`n_dofs`
        """
        if len(parameters) != self.n_dofs:
            raise ValueError('Invalid number of parameters, please provide {} '
                             'parameters'.format(self.n_dofs))
        return self._cvs.dot(parameters)

    def print_tables(self):
        """ Prints information concerning the underlying cluster space to
        stdout, including, e.g., the number of cluster, orbits, and parameters
        by order and number of bodies. """
        self.summary.print_tables()

    def print_orbits(self):
        """ Prints a list of all orbits. """
        orbits = self.orbit_data

        def str_orbit(index, orbit):
            # A negative index renders the header row (column names), any
            # other index renders the data row of the given orbit.
            elements = ' '.join(chemical_symbols[n] for n in
                                orbit['prototype_atom_types'])
            fields = OrderedDict([
                ('index', '{:^3}'.format(index)),
                ('order', '{:^3}'.format(orbit['order'])),
                ('elements', '{:^18}'.format(elements)),
                ('radius', '{:^8.4f}'.format(orbit['radius'])),
                ('prototype', '{:^18}'.format(str(orbit['prototype_cluster']))),
                ('clusters', '{:^4}'.format(orbit['n_clusters'])),
                ('parameters', '{:^3}'.format(len(orbit['eigentensors']))),
            ])

            s = []
            for name, value in fields.items():
                n = max(len(name), len(value))
                if index < 0:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=value, n=n)]
            return ' | '.join(s)

        # table header
        width = max(len(str_orbit(-1, orbits[-1])), len(str_orbit(0, orbits[-1])))
        print(' List of Orbits '.center(width, '='))
        print(str_orbit(-1, orbits[0]))
        print(''.center(width, '-'))

        # table body
        for i, orbit in enumerate(orbits):
            print(str_orbit(i, orbit))
        print(''.center(width, '='))

    def __str__(self):
        """ Returns a multi-line summary: general information followed by a
        table with the number of orbits and clusters per order. """

        def str_order(order, header=False):
            # Renders one table row; with header=True the column names are
            # rendered instead of the values.
            formats = {'order': '{:2}',
                       'n_orbits': '{:5}',
                       'n_clusters': '{:5}'}
            s = []
            for name, value in order.items():
                str_repr = formats[name].format(value)
                n = max(len(name), len(str_repr))
                if header:
                    s += ['{s:^{n}}'.format(s=name, n=n)]
                else:
                    s += ['{s:^{n}}'.format(s=str_repr, n=n)]
            return ' | '.join(s)

        # collect data (the properties return copies, so fetch them once)
        orbits = self.orbit_data
        cutoffs = self.cutoffs
        orders = cutoffs.orders
        prim = self.primitive_structure

        order_data = {o: dict(order=o, n_orbits=0, n_clusters=0) for o in orders}
        for orbit in orbits:
            o = orbit['order']
            order_data[o]['n_orbits'] += 1
            order_data[o]['n_clusters'] += orbit['n_clusters']

        # prototype with max order to find column width
        max_order = max(orders)
        prototype = order_data[max_order]
        n = max(len(str_order(prototype)), 54)

        # basic information
        s = []
        s.append(' Cluster Space '.center(n, '='))
        data = [('Spacegroup', self.spacegroup),
                ('symprec', self.symprec),
                ('Sum rules', self.acoustic_sum_rules),
                ('Length scale', self.length_scale),
                ('Cutoffs', cutoffs),
                ('Cell', prim.cell),
                ('Basis', prim.basis),
                ('Numbers', prim.numbers),
                ('Total number of orbits', len(orbits)),
                ('Total number of clusters',
                 sum(order_data[order]['n_clusters'] for order in orders)),
                ('Total number of parameters', self._get_n_dofs()),
                ]
        for field, value in data:
            # values spanning several lines start on a line of their own
            if str(value).count('\n') > 1:
                s.append('{:26} :\n{}'.format(field, value))
            else:
                s.append('{:26} : {}'.format(field, value))

        # table header
        s.append(''.center(n, '-'))
        s.append(str_order(prototype, header=True))
        s.append(''.center(n, '-'))
        for order in orders:
            s.append(str_order(order_data[order]).rstrip())
        s.append(''.center(n, '='))
        return '\n'.join(s)

    def __repr__(self):
        """ Returns a string built from the primitive structure, the cutoffs
        and the configuration. """
        s = 'ClusterSpace({!r}, {!r}, {!r})'
        return s.format(self.primitive_structure, self.cutoffs, self._config)

    def copy(self):
        """ Returns a deep copy of this cluster space. """
        return deepcopy(self)

    def write(self, fileobj):
        """ Writes cluster space to file.

        The instance is saved into a custom format based on tar-files. The
        resulting file will be a valid tar file and can be browsed by a tar
        reader. The included objects are themselves either pickles, npz or
        other tars.

        Parameters
        ----------
        fileobj : str or file-like object
            If the input is a string a tar archive will be created in the
            current directory. Otherwise the input must be a valid file
            like object.
        """
        # A string is treated as a file name, anything else as a stream.
        if isinstance(fileobj, str):
            open_kwargs = dict(name=fileobj)
        else:
            open_kwargs = dict(fileobj=fileobj)

        # The context manager releases the archive handle even if one of the
        # helpers below raises.
        with tarfile.open(mode='w', **open_kwargs) as tar_file:

            # Attributes in pickle format
            pickle_attributes = ['_config',
                                 '_symmetry_dataset', '_permutations',
                                 '_atom_list', '_cluster_list']
            items_pickle = {attribute: getattr(self, attribute)
                            for attribute in pickle_attributes}
            add_items_to_tarfile_pickle(tar_file, items_pickle, 'attributes')

            # Constraint vectors, stored via the same pickle helper
            items = dict(cvs=self._cvs)
            add_items_to_tarfile_pickle(tar_file, items, 'constraint_vectors')

            # Cutoffs and prim with their builtin write/read functions
            items_custom = {'_cutoffs': self._cutoffs, '_prim': self._prim}
            add_items_to_tarfile_custom(tar_file, items_custom)

            # Orbits
            add_list_to_tarfile_custom(tar_file, self._orbits, 'orbits')
            add_list_to_tarfile_custom(tar_file, self._dropped_orbits,
                                       'dropped_orbits')

    @staticmethod
    def read(f):
        """ Reads a cluster space from file.

        Parameters
        ----------
        f : str or file object
            name of input file (str) or stream to load from (file object)

        Returns
        -------
        ClusterSpace
            the cluster space stored in the file
        """
        # NOTE(review): part of the archive is loaded through
        # read_items_pickle, which presumably unpickles its content; only
        # read files from trusted sources.

        # Instantiate empty cs obj. (bypasses __init__, i.e. nothing is built)
        cs = ClusterSpace.__new__(ClusterSpace)

        # Load from file on disk or file-like
        if isinstance(f, str):
            open_kwargs = dict(name=f)
        else:
            open_kwargs = dict(fileobj=f)

        # The context manager releases the archive handle even if reading
        # one of the members fails.
        with tarfile.open(mode='r', **open_kwargs) as tar_file:

            # Attributes
            attributes = read_items_pickle(tar_file, 'attributes')
            for name, value in attributes.items():
                setattr(cs, name, value)

            # Load the constraint vectors
            items = read_items_pickle(tar_file, 'constraint_vectors')
            cs._cvs = items['cvs']

            # Cutoffs and prim via custom save funcs
            fileobj = tar_file.extractfile('_cutoffs')
            cs._cutoffs = Cutoffs.read(fileobj)

            fileobj = tar_file.extractfile('_prim')
            cs._prim = Atoms.read(fileobj)

            # Orbits are stored in a separate archive
            cs._orbits = read_list_custom(tar_file, 'orbits', Orbit.read)
            cs._dropped_orbits = read_list_custom(
                tar_file, 'dropped_orbits', Orbit.read)

        # create summary object based on CS
        cs.summary = ClusterSpaceData(cs)

        # Done!
        return cs