Coverage for hiphive/cluster_space.py: 98%
236 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
1"""
2Contains the ClusterSpace object central to hiPhive
3"""
6import tarfile
7from collections import OrderedDict
8from copy import deepcopy
9from typing import BinaryIO, TextIO, Union
11import numpy as np
12from ase.data import chemical_symbols
14from .core.cluster_space_builder import build_cluster_space
15from .core.atoms import Atoms
16from .core.orbits import Orbit
17from .core.utilities import BiMap
18from .cluster_space_data import ClusterSpaceData
19from .input_output.logging_tools import logger
20from .input_output.read_write_files import (add_items_to_tarfile_pickle,
21 add_items_to_tarfile_custom,
22 add_list_to_tarfile_custom,
23 read_items_pickle,
24 read_list_custom)
25from .cutoffs import Cutoffs, CutoffMaximumBody, BaseClusterFilter
27from .config import Config
28logger = logger.getChild('ClusterSpace')
31class ClusterSpace:
32 """Primitive object for handling clusters and force constants of a structure.
34 Parameters
35 ----------
36 prototype_structure : ase.Atoms
37 Prototype structure; spglib will be used to find a suitable cell based
38 on this structure unless the cell is already a primitive cell.
39 cutoffs : list or Cutoffs
40 Cutoff radii for different orders starting with second order.
41 cluster_filter : ClusterFilter
42 Accepts a subclass of hiphive.filters.BaseClusterFilter to further
43 control which orbits to include.
44 config : Config object
45 A configuration object that holds information on how the cluster space
46 should be built, e.g., values for tolerances and specifications
47 regarding the handling of acoustic sum rules; if ``config`` is
48 not given then the keyword arguments that follow below can be
49 used for configuration.
50 acoustic_sum_rules : bool
51 If `True` the acoustic sum rules will be enforced by constraining the
52 parameters.
53 symprec : float
54 Numerical precision that will be used for analyzing the symmetry (this
55 parameter will be forwarded to `spglib <https://phonopy.github.io/spglib/>`_).
56 length_scale : float
57 This will be used as a normalization constant for the eigentensors.
59 Examples
60 --------
62 To instantiate a :class:`ClusterSpace` object one has to specify a
63 prototype structure and cutoff radii for each cluster order that
64 should be included. For example the following snippet will set up
65 a :class:`ClusterSpace` object for a body-centered-cubic (BCC)
66 structure including second order terms up to a distance of 5 A and
67 third order terms up to a distance of 4 A.
69 >>> from ase.build import bulk
70 >>> from hiphive import ClusterSpace
71 >>> prim = bulk('W')
72 >>> cs = ClusterSpace(prim, [5.0, 4.0])
74 """
75 # TODO: This class probably needs some more documentation
77 def __init__(self, prototype_structure, cutoffs, config=None,
78 cluster_filter=None, **kwargs):
80 if not all(prototype_structure.pbc):
81 raise ValueError('prototype_structure must have pbc.')
83 if isinstance(cutoffs, Cutoffs):
84 self._cutoffs = cutoffs
85 elif isinstance(cutoffs, list):
86 self._cutoffs = CutoffMaximumBody(cutoffs, len(cutoffs) + 1)
87 else:
88 raise TypeError('cutoffs must be a list or a Cutoffs object')
90 if config is None:
91 config = Config(**kwargs)
92 else:
93 if not isinstance(config, Config): 93 ↛ 94line 93 didn't jump to line 94 because the condition on line 93 was never true
94 raise TypeError('config kw must be of type {}'.format(Config))
95 if len(kwargs) > 0:
96 raise ValueError('use either Config or kwargs, not both')
97 self._config = config
99 if cluster_filter is None:
100 self._cluster_filter = BaseClusterFilter()
101 else:
102 self._cluster_filter = cluster_filter
104 self._atom_list = None
105 self._cluster_list = None
106 self._symmetry_dataset = None
107 self._permutations = None
108 self._prim = None
109 self._orbits = None
111 self._constraint_vectors = None
112 # TODO: How to handle the constraint matrices? Should they even be
113 # stored?
114 self._constraint_matrices = None
115 # Is this the best way or should the prim be instantiated separately?
117 build_cluster_space(self, prototype_structure)
118 self.summary = ClusterSpaceData(self)
120 @property
121 def n_dofs(self) -> int:
122 """Number of free parameters in the model.
124 If the sum rules are not enforced the number of DOFs is the same as
125 the total number of eigentensors in all orbits.
126 """
127 return self._get_n_dofs()
129 @property
130 def cutoffs(self) -> Cutoffs:
131 """ Cutoffs used for constructing the cluster space. """
132 return deepcopy(self._cutoffs)
134 @property
135 def symprec(self) -> float:
136 """ Symprec value used when constructing the cluster space. """
137 return self._config['symprec']
139 @property
140 def acoustic_sum_rules(self) -> bool:
141 """ True if acoustic sum rules are enforced. """
142 return self._config['acoustic_sum_rules']
144 @property
145 def length_scale(self) -> bool:
146 """ Normalization constant of the force constants. """
147 return self._config['length_scale']
149 @property
150 def primitive_structure(self) -> Atoms:
151 """ Structure of the lattice. """
152 return self._prim.copy()
154 @property
155 def spacegroup(self) -> str:
156 """ Space group of the lattice structure obtained from spglib. """
157 return f'{self._symmetry_dataset.international} ({self._symmetry_dataset.number})'
159 @property
160 def wyckoff_sites(self) -> list:
161 """ Wyckoff sites in the primitive cell. """
162 return self._symmetry_dataset.equivalent_atoms
164 @property
165 def rotation_matrices(self) -> list[np.ndarray]:
166 """ Symmetry elements (`3x3` matrices) representing rotations. """
167 return self._symmetry_dataset.rotations.copy()
169 @property
170 def translation_vectors(self) -> list[np.ndarray]:
171 """ Symmetry elements representing translations. """
172 # TODO: bug incoming!
173 return (self._symmetry_dataset.translations % 1).copy()
175 @property
176 def permutations(self) -> list[np.ndarray]:
177 """ Lookup for permutation references. """
178 return deepcopy(self._permutations)
180 @property
181 def atom_list(self) -> BiMap:
182 """ Atoms inside the cutoff relative to the of the center cell. """
183 return self._atom_list
185 @property
186 def cluster_list(self) -> BiMap:
187 """ Clusters possible within the cutoff. """
188 return self._cluster_list
190 @property
191 def orbits(self) -> list[Orbit]: # TODO: add __getitem__ method
192 """ Orbits associated with the lattice structure. """
193 return self._orbits
195 @property
196 def orbit_data(self) -> list[dict]:
197 """ Detailed information for each orbit, e.g., cluster radius and atom types.
198 """
199 data = []
200 p = 0
201 for orbit_index, orbit in enumerate(self.orbits):
202 d = {}
203 d['index'] = orbit_index
204 d['order'] = orbit.order
205 d['radius'] = orbit.radius
206 d['maximum_distance'] = orbit.maximum_distance
207 d['n_clusters'] = len(orbit.orientation_families)
208 d['eigentensors'] = orbit.eigentensors
209 d['n_parameters'] = len(d['eigentensors'])
211 types, wyckoff_sites = [], []
212 for atom_index in self.cluster_list[orbit.prototype_index]:
213 atom = self.atom_list[atom_index]
214 types.append(self.primitive_structure.numbers[atom.site])
215 wyckoff_sites.append(self.wyckoff_sites[atom.site])
216 d['prototype_cluster'] = self.cluster_list[orbit.prototype_index]
217 d['prototype_atom_types'] = tuple(types)
218 d['prototype_wyckoff_sites'] = tuple(wyckoff_sites)
220 d['geometrical_order'] = len(set(d['prototype_cluster']))
221 d['parameter_indices'] = np.arange(p, p + len(orbit.eigentensors))
223 p += len(orbit.eigentensors)
224 data.append(d)
226 return data
228 def get_parameter_indices(self, order: int) -> list[int]:
229 """
230 Returns a list of the parameter indices associated with the requested
231 order.
233 Parameters
234 ----------
235 order
236 Order for which to return the parameter indices.
238 Returns
239 -------
240 List of parameter indices associated with the requested order.
242 Raises
243 ------
244 ValueError
245 If the order is not included in the cluster space.
246 """
247 order = int(order)
248 if order not in self.cutoffs.orders: 248 ↛ 249line 248 didn't jump to line 249 because the condition on line 248 was never true
249 raise ValueError('Order must be in {}'.format(self.cutoffs.orders))
250 min_param = 0
251 max_param = 0
252 for orbit in self.orbits:
253 if orbit.order < order:
254 min_param += len(orbit.eigentensors)
255 max_param = min_param
256 elif orbit.order == order:
257 max_param += len(orbit.eigentensors)
258 else:
259 break
260 rows, cols = self._cvs.nonzero()
261 parameters = set()
262 for r, c in zip(rows, cols):
263 if min_param <= r < max_param:
264 parameters.add(c)
265 for r, c in zip(rows, cols):
266 if c in parameters:
267 assert min_param <= r < max_param, 'The internals are broken!'
269 return sorted(parameters)
271 def get_n_dofs_by_order(self, order: int) -> int:
272 """ Returns number of degrees of freedom for the given order.
274 Parameters
275 ----------
276 order
277 Order for which to return the number of DOFs.
279 Returns
280 -------
281 Number of degrees of freedom.
282 """
283 return len(self.get_parameter_indices(order=order))
285 def _get_n_dofs(self):
286 """ Returns the number of degrees of freedom. """
287 return self._cvs.shape[1]
289 def _map_parameters(self, parameters):
290 """ Maps irreducible parameters to the real parameters associated with
291 the eigentensors.
292 """
293 if len(parameters) != self.n_dofs: 293 ↛ 294line 293 didn't jump to line 294 because the condition on line 293 was never true
294 raise ValueError('Invalid number of parameters, please provide {} '
295 'parameters'.format(self.n_dofs))
296 return self._cvs.dot(parameters)
298 def print_tables(self):
299 """ Prints information concerning the underlying cluster space to stdout, including,
300 e.g., the number of cluster, orbits, and parameters by order and number of bodies. """
301 self.summary.print_tables()
303 def print_orbits(self):
304 """ Prints a list of all orbits. """
305 orbits = self.orbit_data
307 def str_orbit(index, orbit):
308 elements = ' '.join(chemical_symbols[n] for n in
309 orbit['prototype_atom_types'])
310 fields = OrderedDict([
311 ('index', '{:^3}'.format(index)),
312 ('order', '{:^3}'.format(orbit['order'])),
313 ('elements', '{:^18}'.format(elements)),
314 ('radius', '{:^8.4f}'.format(orbit['radius'])),
315 ('prototype', '{:^18}'.format(str(orbit['prototype_cluster']))),
316 ('clusters', '{:^4}'.format(orbit['n_clusters'])),
317 ('parameters', '{:^3}'.format(len(orbit['eigentensors']))),
318 ])
320 s = []
321 for name, value in fields.items():
322 n = max(len(name), len(value))
323 if index < 0:
324 s += ['{s:^{n}}'.format(s=name, n=n)]
325 else:
326 s += ['{s:^{n}}'.format(s=value, n=n)]
327 return ' | '.join(s)
329 # table header
330 width = max(len(str_orbit(-1, orbits[-1])), len(str_orbit(0, orbits[-1])))
331 print(' List of Orbits '.center(width, '='))
332 print(str_orbit(-1, orbits[0]))
333 print(''.center(width, '-'))
335 # table body
336 for i, orbit in enumerate(orbits):
337 print(str_orbit(i, orbit))
338 print(''.center(width, '='))
340 def __str__(self):
342 def str_order(order, header: bool = False):
343 formats = {'order': '{:2}',
344 'n_orbits': '{:5}',
345 'n_clusters': '{:5}'}
346 s = []
347 for name, value in order.items():
348 str_repr = formats[name].format(value)
349 n = max(len(name), len(str_repr))
350 if header:
351 s += ['{s:^{n}}'.format(s=name, n=n)]
352 else:
353 s += ['{s:^{n}}'.format(s=str_repr, n=n)]
354 return ' | '.join(s)
356 # collect data
357 orbits = self.orbit_data
358 orders = self.cutoffs.orders
360 order_data = {o: dict(order=o, n_orbits=0, n_clusters=0) for o in orders}
361 for orbit in orbits:
362 o = orbit['order']
363 order_data[o]['n_orbits'] += 1
364 order_data[o]['n_clusters'] += orbit['n_clusters']
366 # prototype with max order to find column width
367 max_order = max(orders)
368 prototype = order_data[max_order]
369 n = max(len(str_order(prototype)), 54)
371 # basic information
372 s = []
373 s.append(' Cluster Space '.center(n, '='))
374 data = [('Spacegroup', self.spacegroup),
375 ('symprec', self.symprec),
376 ('Sum rules', self.acoustic_sum_rules),
377 ('Length scale', self.length_scale),
378 ('Cutoffs', self.cutoffs),
379 ('Cell', self.primitive_structure.cell),
380 ('Basis', self.primitive_structure.basis),
381 ('Numbers', self.primitive_structure.numbers),
382 ('Total number of orbits', len(orbits)),
383 ('Total number of clusters',
384 sum([order_data[order]['n_clusters'] for order in orders])),
385 ('Total number of parameters', self._get_n_dofs()
386 )]
387 for field, value in data:
388 if str(value).count('\n') > 1:
389 s.append('{:26} :\n{}'.format(field, value))
390 else:
391 s.append('{:26} : {}'.format(field, value))
393 # table header
394 s.append(''.center(n, '-'))
395 s.append(str_order(prototype, header=True))
396 s.append(''.center(n, '-'))
397 for order in orders:
398 s.append(str_order(order_data[order]).rstrip())
399 s.append(''.center(n, '='))
400 return '\n'.join(s)
402 def __repr__(self):
403 s = 'ClusterSpace({!r}, {!r}, {!r})'
404 return s.format(self.primitive_structure, self.cutoffs, self._config)
406 def copy(self):
407 return deepcopy(self)
409 def write(self, fileobj: Union[str, BinaryIO, TextIO]):
410 """ Writes cluster space to file.
412 The instance is saved into a custom format based on tar-files. The
413 resulting file will be a valid tar file and can be browsed by by a tar
414 reader. The included objects are themself either pickles, npz or other
415 tars.
417 Parameters
418 ----------
419 fileobj
420 If the input is a string a tar archive will be created in the
421 current directory. Otherwise the input must be a valid file
422 like object.
423 """
424 # Create a tar archive
425 if isinstance(fileobj, str):
426 tar_file = tarfile.open(name=fileobj, mode='w')
427 else:
428 tar_file = tarfile.open(fileobj=fileobj, mode='w')
430 # Attributes in pickle format
431 pickle_attributes = ['_config',
432 '_symmetry_dataset', '_permutations',
433 '_atom_list', '_cluster_list']
434 items_pickle = dict()
435 for attribute in pickle_attributes:
436 items_pickle[attribute] = self.__getattribute__(attribute)
437 add_items_to_tarfile_pickle(tar_file, items_pickle, 'attributes')
439 # Constraint matrices and vectors in hdf5 format
440 items = dict(cvs=self._cvs)
441 add_items_to_tarfile_pickle(tar_file, items, 'constraint_vectors')
443 # Cutoffs and prim with their builtin write/read functions
444 items_custom = {'_cutoffs': self._cutoffs, '_prim': self._prim}
445 add_items_to_tarfile_custom(tar_file, items_custom)
447 # Orbits
448 add_list_to_tarfile_custom(tar_file, self._orbits, 'orbits')
449 add_list_to_tarfile_custom(tar_file, self._dropped_orbits,
450 'dropped_orbits')
452 # Done!
453 tar_file.close()
455 def read(f: Union[str, BinaryIO, TextIO]):
456 """ Reads a cluster space from file.
458 Parameters
459 ----------
460 f
461 Name of input file (`str`) or stream to load from (file object).
462 """
464 # Instantiate empty cs obj.
465 cs = ClusterSpace.__new__(ClusterSpace)
467 # Load from file on disk or file-like
468 if type(f) is str:
469 tar_file = tarfile.open(mode='r', name=f)
470 else:
471 tar_file = tarfile.open(mode='r', fileobj=f)
473 # Attributes
474 attributes = read_items_pickle(tar_file, 'attributes')
475 for name, value in attributes.items():
476 cs.__setattr__(name, value)
478 # Load the constraint matrices into their dict
479 items = read_items_pickle(tar_file, 'constraint_vectors')
480 cs._cvs = items['cvs']
482 # Cutoffs and prim via custom save funcs
483 fileobj = tar_file.extractfile('_cutoffs')
484 cs._cutoffs = Cutoffs.read(fileobj)
486 fileobj = tar_file.extractfile('_prim')
487 cs._prim = Atoms.read(fileobj)
489 # Orbits are stored in a separate archive
490 cs._orbits = read_list_custom(tar_file, 'orbits', Orbit.read)
491 cs._dropped_orbits = read_list_custom(
492 tar_file, 'dropped_orbits', Orbit.read)
494 tar_file.close()
496 # create summary object based on CS
497 cs.summary = ClusterSpaceData(cs)
499 # Done!
500 return cs