Coverage for hiphive/cluster_space.py: 98%
234 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1"""
2Contains the ClusterSpace object central to hiPhive
3"""
5import numpy as np
6import tarfile
8from ase.data import chemical_symbols
9from collections import OrderedDict
10from copy import deepcopy
12from .core.cluster_space_builder import build_cluster_space
13from .core.atoms import Atoms
14from .core.orbits import Orbit
15from .cluster_space_data import ClusterSpaceData
16from .input_output.logging_tools import logger
17from .input_output.read_write_files import (add_items_to_tarfile_pickle,
18 add_items_to_tarfile_custom,
19 add_list_to_tarfile_custom,
20 read_items_pickle,
21 read_list_custom)
22from .cutoffs import Cutoffs, CutoffMaximumBody, BaseClusterFilter
24from .config import Config
# Rebinds the module-level name ``logger`` (imported from
# .input_output.logging_tools above) to its child logger named 'ClusterSpace'.
25logger = logger.getChild('ClusterSpace')
class ClusterSpace:
    """Primitive object for handling clusters and force constants of a structure.

    Parameters
    ----------
    prototype_structure : ase.Atoms
        prototype structure; spglib will be used to find a suitable cell based
        on this structure unless the cell is already a primitive cell.
    cutoffs : list or Cutoffs
        cutoff radii for different orders starting with second order
    cluster_filter : ClusterFilter
        accepts a subclass of hiphive.filters.BaseClusterFilter to further
        control which orbits to include.
    config : Config object
        a configuration object that holds information on how the cluster space
        should be built, e.g., values for tolerances and specifications
        regarding the handling of acoustic sum rules; if ``config`` is
        not given then the keyword arguments that follow below can be
        used for configuration.
    acoustic_sum_rules : bool
        If True the acoustic sum rules will be enforced by constraining the
        parameters.
    symprec : float
        numerical precision that will be used for analyzing the symmetry (this
        parameter will be forwarded to
        `spglib <https://atztogo.github.io/spglib/>`_)
    length_scale : float
        This will be used as a normalization constant for the eigentensors

    Examples
    --------

    To instantiate a :class:`ClusterSpace` object one has to specify a
    prototype structure and cutoff radii for each cluster order that
    should be included. For example the following snippet will set up
    a :class:`ClusterSpace` object for a body-centered-cubic (BCC)
    structure including second order terms up to a distance of 5 A and
    third order terms up to a distance of 4 A.

    >>> from ase.build import bulk
    >>> from hiphive import ClusterSpace
    >>> prim = bulk('W')
    >>> cs = ClusterSpace(prim, [5.0, 4.0])

    """
    # TODO: This class probably needs some more documentation

    def __init__(self, prototype_structure, cutoffs, config=None,
                 cluster_filter=None, **kwargs):

        if not all(prototype_structure.pbc):
            raise ValueError('prototype_structure must have pbc.')

        # A plain list is interpreted as one cutoff per order, starting with
        # second order, hence the maximum body count of len(cutoffs) + 1.
        if isinstance(cutoffs, Cutoffs):
            self._cutoffs = cutoffs
        elif isinstance(cutoffs, list):
            self._cutoffs = CutoffMaximumBody(cutoffs, len(cutoffs) + 1)
        else:
            raise TypeError('cutoffs must be a list or a Cutoffs object')

        # The configuration comes either from an explicit Config object or
        # from keyword arguments, never from both.
        if config is None:
            config = Config(**kwargs)
        else:
            if not isinstance(config, Config):
                raise TypeError('config kw must be of type {}'.format(Config))
            if kwargs:
                raise ValueError('use either Config or kwargs, not both')
        self._config = config

        if cluster_filter is None:
            self._cluster_filter = BaseClusterFilter()
        else:
            self._cluster_filter = cluster_filter

        self._atom_list = None
        self._cluster_list = None
        self._symmetry_dataset = None
        self._permutations = None
        self._prim = None
        self._orbits = None

        self._constraint_vectors = None
        # TODO: How to handle the constraint matrices? Should they even be
        # stored?
        self._constraint_matrices = None
        # Is this the best way or should the prim be instantiated separately?

        build_cluster_space(self, prototype_structure)
        self.summary = ClusterSpaceData(self)

    # TODO: Should everything here be properties? deepcopy/ref etc.?
    @property
    def n_dofs(self):
        """int : number of free parameters in the model

        If the sum rules are not enforced the number of DOFs is the same as
        the total number of eigentensors in all orbits.
        """
        return self._get_n_dofs()

    @property
    def cutoffs(self):
        """ Cutoffs : cutoffs used for constructing the cluster space
        (returned as a deep copy) """
        return deepcopy(self._cutoffs)

    @property
    def symprec(self):
        """ float : symprec value used when constructing the cluster space """
        return self._config['symprec']

    @property
    def acoustic_sum_rules(self):
        """ bool : True if acoustic sum rules are enforced """
        return self._config['acoustic_sum_rules']

    @property
    def length_scale(self):
        """ float : normalization constant of the force constants """
        return self._config['length_scale']

    @property
    def primitive_structure(self):
        """ ase.Atoms : structure of the lattice (returned as a copy) """
        return self._prim.copy()

    @property
    def spacegroup(self):
        """ str : space group of the lattice structure obtained from spglib """
        dataset = self._symmetry_dataset
        return f'{dataset.international} ({dataset.number})'

    @property
    def wyckoff_sites(self):
        """ list : wyckoff sites in the primitive cell """
        return self._symmetry_dataset.equivalent_atoms

    @property
    def rotation_matrices(self):
        """ list(numpy.ndarray) : symmetry elements (`3x3` matrices)
        representing rotations """
        return self._symmetry_dataset.rotations.copy()

    @property
    def translation_vectors(self):
        """ list(numpy.ndarray) : symmetry elements representing
        translations """
        # TODO: bug incoming!
        return (self._symmetry_dataset.translations % 1).copy()

    @property
    def permutations(self):
        """ list(numpy.ndarray) : lookup for permutation references """
        return deepcopy(self._permutations)

    @property
    def atom_list(self):
        """ BiMap : atoms inside the cutoff relative to the center cell """
        return self._atom_list

    @property
    def cluster_list(self):
        """ BiMap : clusters possible within the cutoff """
        return self._cluster_list

    @property
    def orbits(self):  # TODO: add __getitem__ method
        """ list(Orbit) : orbits associated with the lattice structure. """
        return self._orbits

    @property
    def orbit_data(self):
        """ list(dict) : detailed information for each orbit, e.g., cluster
        radius and atom types.
        """
        # ``primitive_structure`` returns a fresh copy on every access, so
        # the lookup tables are fetched once instead of once per atom.
        numbers = self.primitive_structure.numbers
        wyckoff = self.wyckoff_sites
        cluster_list = self.cluster_list
        atom_list = self.atom_list

        data = []
        p = 0  # running offset of the first parameter of the current orbit
        for orbit_index, orbit in enumerate(self.orbits):
            eigentensors = orbit.eigentensors
            n_parameters = len(eigentensors)
            prototype_cluster = cluster_list[orbit.prototype_index]

            d = {}
            d['index'] = orbit_index
            d['order'] = orbit.order
            d['radius'] = orbit.radius
            d['maximum_distance'] = orbit.maximum_distance
            d['n_clusters'] = len(orbit.orientation_families)
            d['eigentensors'] = eigentensors
            d['n_parameters'] = n_parameters

            sites = [atom_list[atom_index].site
                     for atom_index in prototype_cluster]
            d['prototype_cluster'] = prototype_cluster
            d['prototype_atom_types'] = tuple(numbers[site] for site in sites)
            d['prototype_wyckoff_sites'] = tuple(wyckoff[site]
                                                 for site in sites)

            # number of distinct atoms in the prototype cluster
            d['geometrical_order'] = len(set(prototype_cluster))
            d['parameter_indices'] = np.arange(p, p + n_parameters)

            p += n_parameters
            data.append(d)

        return data

    def get_parameter_indices(self, order):
        """
        Returns a list of the parameter indices associated with the requested
        order.

        Parameters
        ----------
        order : int
            order for which to return the parameter indices

        Returns
        -------
        list(int)
            list of parameter indices associated with the requested order

        Raises
        ------
        ValueError
            if the order is not included in the cluster space
        """
        order = int(order)
        # Read the orders directly; the ``cutoffs`` property would deep copy
        # the whole Cutoffs object on every access.
        orders = self._cutoffs.orders
        if order not in orders:
            raise ValueError('Order must be in {}'.format(orders))

        # Half-open range [min_param, max_param) of eigentensor rows that
        # belong to the requested order. Iteration stops at the first orbit
        # of a higher order.
        min_param = 0
        max_param = 0
        for orbit in self.orbits:
            if orbit.order < order:
                min_param += len(orbit.eigentensors)
                max_param = min_param
            elif orbit.order == order:
                max_param += len(orbit.eigentensors)
            else:
                break

        # Columns of the constraint vectors with a non-zero entry inside the
        # row range are the parameters of this order.
        rows, cols = self._cvs.nonzero()
        inside = (rows >= min_param) & (rows < max_param)
        parameters = set(cols[inside])

        # Internal consistency check: a parameter of this order must not
        # couple to rows outside of the range.
        assert not parameters.intersection(cols[~inside]), \
            'The internals are broken!'

        return sorted(parameters)

    def get_n_dofs_by_order(self, order):
        """ Returns number of degrees of freedom for the given order.

        Parameters
        ----------
        order : int
            order for which to return the number of dofs

        Returns
        -------
        int
            number of degrees of freedom
        """
        return len(self.get_parameter_indices(order=order))

    def _get_n_dofs(self):
        """ Returns the number of degrees of freedom, i.e., the number of
        columns of the constraint vectors. """
        return self._cvs.shape[1]

    def _map_parameters(self, parameters):
        """ Maps irreducible parameters to the real parameters associated with
        the eigentensors.

        Raises
        ------
        ValueError
            if the number of parameters differs from :attr:`n_dofs`
        """
        if len(parameters) != self.n_dofs:
            raise ValueError('Invalid number of parameters, please provide {} '
                             'parameters'.format(self.n_dofs))
        return self._cvs.dot(parameters)

    def print_tables(self):
        """ Prints information concerning the underlying cluster space to
        stdout, including, e.g., the number of clusters, orbits, and
        parameters by order and number of bodies. """
        self.summary.print_tables()

    def print_orbits(self):
        """ Prints a list of all orbits. """
        orbits = self.orbit_data

        def str_orbit(index, orbit):
            """ Formats one table row; a negative index yields the header. """
            elements = ' '.join(chemical_symbols[n] for n in
                                orbit['prototype_atom_types'])
            fields = OrderedDict([
                ('index', '{:^3}'.format(index)),
                ('order', '{:^3}'.format(orbit['order'])),
                ('elements', '{:^18}'.format(elements)),
                ('radius', '{:^8.4f}'.format(orbit['radius'])),
                ('prototype', '{:^18}'.format(str(orbit['prototype_cluster']))),
                ('clusters', '{:^4}'.format(orbit['n_clusters'])),
                ('parameters', '{:^3}'.format(len(orbit['eigentensors']))),
            ])

            is_header = index < 0
            columns = []
            for name, value in fields.items():
                # each column is as wide as the longer of its name and value
                n = max(len(name), len(value))
                text = name if is_header else value
                columns.append('{s:^{n}}'.format(s=text, n=n))
            return ' | '.join(columns)

        # table header
        width = max(len(str_orbit(-1, orbits[-1])), len(str_orbit(0, orbits[-1])))
        print(' List of Orbits '.center(width, '='))
        print(str_orbit(-1, orbits[0]))
        print(''.center(width, '-'))

        # table body
        for i, orbit in enumerate(orbits):
            print(str_orbit(i, orbit))
        print(''.center(width, '='))

    def __str__(self):
        """ Returns a multi-line summary: general information about the
        cluster space followed by a table of orbits and clusters per order.
        """

        def str_order(order, header=False):
            """ Formats one row of the per-order table (or its header). """
            formats = {'order': '{:2}',
                       'n_orbits': '{:5}',
                       'n_clusters': '{:5}'}
            columns = []
            for name, value in order.items():
                str_repr = formats[name].format(value)
                n = max(len(name), len(str_repr))
                text = name if header else str_repr
                columns.append('{s:^{n}}'.format(s=text, n=n))
            return ' | '.join(columns)

        # collect data; the properties below return copies, so each is
        # fetched only once
        orbits = self.orbit_data
        cutoffs = self.cutoffs
        orders = cutoffs.orders
        prim = self.primitive_structure

        order_data = {o: dict(order=o, n_orbits=0, n_clusters=0) for o in orders}
        for orbit in orbits:
            o = orbit['order']
            order_data[o]['n_orbits'] += 1
            order_data[o]['n_clusters'] += orbit['n_clusters']

        # prototype with max order to find column width
        prototype = order_data[max(orders)]
        n = max(len(str_order(prototype)), 54)

        # basic information
        s = [' Cluster Space '.center(n, '=')]
        data = [('Spacegroup', self.spacegroup),
                ('symprec', self.symprec),
                ('Sum rules', self.acoustic_sum_rules),
                ('Length scale', self.length_scale),
                ('Cutoffs', cutoffs),
                ('Cell', prim.cell),
                ('Basis', prim.basis),
                ('Numbers', prim.numbers),
                ('Total number of orbits', len(orbits)),
                ('Total number of clusters',
                 sum(order_data[order]['n_clusters'] for order in orders)),
                ('Total number of parameters', self._get_n_dofs())]
        for field, value in data:
            # values spanning several lines start on a line of their own
            if str(value).count('\n') > 1:
                s.append('{:26} :\n{}'.format(field, value))
            else:
                s.append('{:26} : {}'.format(field, value))

        # table header
        s.append(''.center(n, '-'))
        s.append(str_order(prototype, header=True))
        s.append(''.center(n, '-'))
        for order in orders:
            s.append(str_order(order_data[order]).rstrip())
        s.append(''.center(n, '='))
        return '\n'.join(s)

    def __repr__(self):
        s = 'ClusterSpace({!r}, {!r}, {!r})'
        return s.format(self.primitive_structure, self.cutoffs, self._config)

    def copy(self):
        """ Returns a deep copy of this cluster space. """
        return deepcopy(self)

    def write(self, fileobj):
        """ Writes cluster space to file.

        The instance is saved into a custom format based on tar-files. The
        resulting file will be a valid tar file and can be browsed by a tar
        reader. The included objects are themselves either pickles, npz or
        other tars.

        Parameters
        ----------
        fileobj : str or file-like object
            If the input is a string a tar archive will be created in the
            current directory. Otherwise the input must be a valid file
            like object.
        """
        # Create a tar archive
        if isinstance(fileobj, str):
            tar_file = tarfile.open(name=fileobj, mode='w')
        else:
            tar_file = tarfile.open(fileobj=fileobj, mode='w')

        # The context manager closes the archive even if a step below fails.
        with tar_file:
            # Attributes in pickle format
            pickle_attributes = ['_config',
                                 '_symmetry_dataset', '_permutations',
                                 '_atom_list', '_cluster_list']
            items_pickle = {attribute: getattr(self, attribute)
                            for attribute in pickle_attributes}
            add_items_to_tarfile_pickle(tar_file, items_pickle, 'attributes')

            # Constraint vectors in pickle format
            items = dict(cvs=self._cvs)
            add_items_to_tarfile_pickle(tar_file, items, 'constraint_vectors')

            # Cutoffs and prim with their builtin write/read functions
            items_custom = {'_cutoffs': self._cutoffs, '_prim': self._prim}
            add_items_to_tarfile_custom(tar_file, items_custom)

            # Orbits
            add_list_to_tarfile_custom(tar_file, self._orbits, 'orbits')
            add_list_to_tarfile_custom(tar_file, self._dropped_orbits,
                                       'dropped_orbits')

    @staticmethod
    def read(f):
        """ Reads a cluster space from file.

        NOTE: parts of the archive are stored as pickles (see :meth:`write`);
        only read files from sources you trust.

        Parameters
        ----------
        f : str or file object
            name of input file (str) or stream to load from (file object)

        Returns
        -------
        ClusterSpace
            the cluster space stored in the file
        """

        # Instantiate empty cs obj. (bypasses __init__, so nothing is rebuilt)
        cs = ClusterSpace.__new__(ClusterSpace)

        # Load from file on disk or file-like
        if isinstance(f, str):
            tar_file = tarfile.open(mode='r', name=f)
        else:
            tar_file = tarfile.open(mode='r', fileobj=f)

        # The context manager closes the archive even if a step below fails.
        with tar_file:
            # Attributes
            attributes = read_items_pickle(tar_file, 'attributes')
            for name, value in attributes.items():
                setattr(cs, name, value)

            # Load the constraint vectors
            items = read_items_pickle(tar_file, 'constraint_vectors')
            cs._cvs = items['cvs']

            # Cutoffs and prim via custom save funcs
            fileobj = tar_file.extractfile('_cutoffs')
            cs._cutoffs = Cutoffs.read(fileobj)

            fileobj = tar_file.extractfile('_prim')
            cs._prim = Atoms.read(fileobj)

            # Orbits are stored in a separate archive
            cs._orbits = read_list_custom(tar_file, 'orbits', Orbit.read)
            cs._dropped_orbits = read_list_custom(
                tar_file, 'dropped_orbits', Orbit.read)

        # create summary object based on CS
        cs.summary = ClusterSpaceData(cs)

        # Done!
        return cs