import itertools
import numpy as np
import spglib as spg
from collections import Counter
from ase.neighborlist import NeighborList

from .atoms import Atom, Atoms
from .clusters import get_clusters
from .orbits import get_orbits
from ..input_output.logging_tools import logger
from .eigentensors import create_eigentensors as _create_eigentensors
from .tensors import rotate_tensor_precalc, rotation_tensor_as_matrix
from .translational_constraints import create_constraint_map as _create_constraint_map
from .utilities import BiMap, ase_atoms_to_spglib_tuple

def build_cluster_space(cluster_space, prototype_structure):
    """Populate ``cluster_space`` step by step from ``prototype_structure``.

    Each step is delegated to a module-level helper that reads earlier results
    from ``cluster_space`` and stores its own result there. Order matters:
    permutations, primitive cell, symmetry dataset, neighbor atom list,
    clusters, orbits, eigentensors, orbit pruning, orientation-family
    eigentensors, constraint map, and finally rescaling/normalization/rotation
    into Cartesian coordinates.

    Parameters
    ----------
    cluster_space
        The cluster space object being built (mutated in place).
    prototype_structure
        Input structure; only used to derive the primitive cell.
    """
    # The permutation list is an indexed fast lookup table for permutation
    # vectors.
    logger.debug('Populate permutation list')
    _create_permutations(cluster_space)

    # The primitive cell is calculated by spglib and contains the cell metric,
    # basis and atomic numbers. After this the prototype structure is no
    # longer used.
    logger.debug('Get primitive cell')
    _create_primitive_cell(cluster_space, prototype_structure)

    # The symmetries are calculated by spglib; the main information is the
    # rotation matrices and translation vectors for each symmetry operation in
    # scaled coordinates.
    logger.debug('Get symmetries')
    _create_symmetry_dataset(cluster_space)

    # The neighbor atoms of the center cell are stored in an indexed list that
    # contains all atoms within the maximum cutoff.
    logger.debug('Find neighbors')
    _create_atom_list(cluster_space)

    # Clusters are generated as combinations of the indices of the atoms in
    # the atom list. Whether a cluster is included depends on the cutoffs
    # object (e.g. all atoms within some distance of each other, possibly
    # depending on expansion order and number of atoms in the cluster).
    logger.debug('Starting generating clusters')
    _create_cluster_list(cluster_space)
    logger.debug('Finished generating clusters')

    # The clusters are categorized into orbits and the eigensymmetries are
    # stored in each orbit.
    logger.debug('Starting categorizing clusters')
    _create_orbits(cluster_space)
    logger.debug('Finished categorizing clusters')

    # The eigensymmetries from the previous step are used to generate valid
    # eigentensors.
    logger.debug('Starting finding eigentensors')
    _create_eigentensors(cluster_space)
    logger.debug('Finished finding eigentensors')

    # Orbits that cannot carry a force constant due to symmetry are removed
    # from the _orbits attribute and kept in _dropped_orbits.
    logger.debug('Dropping orbits...')
    _drop_orbits(cluster_space)

    # Each orientation family gets its rotated version of the orbit's
    # eigentensors.
    logger.debug('Rotating eigentensors into ofs...')
    _populate_ofs_with_ets(cluster_space)

    # The matrix describing the mapping that preserves the global symmetries
    # (translational and rotational) is created.
    logger.debug('Constructing constraint map')
    _create_constraint_map(cluster_space)

    logger.info('Constraints:')
    logger.info('  Acoustic: {}'.format(cluster_space.acoustic_sum_rules))
    ndofs_by_order = {o: cluster_space.get_n_dofs_by_order(o)
                      for o in cluster_space.cutoffs.orders}
    logger.info('  Number of degrees of freedom: {}'.format(ndofs_by_order))
    for order, count in ndofs_by_order.items():
        if count == 0:
            logger.warning(
                '  Warning: No degrees of freedom exists for order {}'.format(order))
    logger.info('  Total number of degrees of freedom: {}'.format(cluster_space.n_dofs))

    # The eigentensors are rescaled depending on order to give a better
    # behaved system for fitting. NOTE: the constraints are rescaled too.
    logger.debug('Rescale eigentensors')
    _rescale_eigentensors(cluster_space)
    logger.debug('Normalize constraints')
    _normalize_constraint_vectors(cluster_space)
    logger.debug('Rotate tensors to Cartesian coordinates')
    _rotate_eigentensors(cluster_space)
# TODO: Actual input could be just the maximum order
# TODO: No side effects, returns only the permutation BiMap
def _create_permutations(cs):
    """Store every permutation of ``range(order)`` for each cutoff order.

    The permutations are appended to a fresh ``BiMap`` order by order (in the
    sequence given by ``cs.cutoffs.orders``), each order's permutations in
    ``itertools.permutations`` order. The result is stored as
    ``cs._permutations``.
    """
    all_perms = itertools.chain.from_iterable(
        itertools.permutations(range(n)) for n in cs.cutoffs.orders)
    lookup = BiMap()
    for perm in all_perms:
        lookup.append(perm)
    cs._permutations = lookup


def _create_primitive_cell(cs, prototype_structure):
    """Find the primitive cell of ``prototype_structure`` and store it as ``cs._prim``.

    spglib's ``standardize_cell`` (``to_primitive=True``, ``no_idealize=True``,
    tolerance ``cs.symprec``) proposes a primitive cell. If it contains the
    same atomic numbers as the prototype and has the same volume (within
    ``cs.symprec``), the prototype itself is reused, wrapped into its cell.
    Otherwise an ``Atoms`` object is built from the spglib cell. A short
    summary of the chosen cell is logged.

    Parameters
    ----------
    cs
        Cluster space under construction; ``cs.symprec`` is read and
        ``cs._prim`` is written.
    prototype_structure
        Atoms object describing the input structure.

    Raises
    ------
    ValueError
        If ``spg.standardize_cell`` returns ``None`` (no cell found).
    """
    structure_as_tuple = ase_atoms_to_spglib_tuple(prototype_structure)
    spgPrim = spg.standardize_cell(structure_as_tuple,
                                   no_idealize=True,
                                   to_primitive=True,
                                   symprec=cs.symprec)
    if spgPrim is None:
        raise ValueError('spglib failed to find a primitive cell '
                         '(symprec={})'.format(cs.symprec))
    spg_cell, spg_basis, spg_numbers = spgPrim[0], spgPrim[1], spgPrim[2]

    numbers_match = sorted(spg_numbers) == sorted(prototype_structure.numbers)
    spg_cell_volume = np.abs(np.linalg.det(spg_cell.T))
    prototype_cell_volume = np.abs(np.linalg.det(prototype_structure.cell.T))
    # NOTE(review): a volume is compared using a length tolerance (symprec);
    # kept unchanged, but a relative tolerance may be more appropriate.
    cell_volume_match = np.isclose(spg_cell_volume,
                                   prototype_cell_volume,
                                   atol=cs.symprec,
                                   rtol=0)

    if numbers_match and cell_volume_match:
        prim = Atoms(prototype_structure)
        prim.wrap()
    else:
        basis = spg_basis
        # Fractional coordinates just below 1 are rounded to 8 decimals and
        # wrapped into [0, 1).
        if np.any(basis > (1 - cs.symprec)):
            logger.debug('Found basis close to 1:\n {}'.format(str(basis)))
            basis = basis.round(8) % 1
            logger.debug('Wrapping to:\n {}'.format(str(basis)))
        prim = Atoms(cell=spg_cell,
                     scaled_positions=basis,
                     numbers=spg_numbers,
                     pbc=True)

    # log primitive cell information
    logger.info('Primitive cell:')
    logger.info('  Formula: {}'.format(prim.get_chemical_formula()))
    logger.info(('  Cell:' + '\n   [{:9.5f} {:9.5f} {:9.5f}]'*3).format(
        *prim.cell[0], *prim.cell[1], *prim.cell[2]))
    logger.info('  Basis:')
    # Small cells are printed in full; larger ones show the first 3 sites.
    truncated = len(prim) >= 5
    shown = prim[:3] if truncated else prim
    for symbol, spos in zip(shown.get_chemical_symbols(), shown.basis):
        logger.info('    {:2} [{:9.5f} {:9.5f} {:9.5f}]'.format(symbol, *spos))
    if truncated:
        logger.info('    ...')
    logger.info('')

    cs._prim = prim


def _create_symmetry_dataset(cs):
    """Compute the spglib symmetry dataset of the primitive cell.

    The dataset is obtained with ``spg.get_symmetry_dataset`` at tolerance
    ``cs.symprec``, stored as ``cs._symmetry_dataset``, and a summary
    (space group, number of distinct Wyckoff sites, number of symmetry
    operations, symprec) is logged.

    Raises
    ------
    ValueError
        If spglib returns ``None`` instead of a dataset.
    """
    prim_as_tuple = ase_atoms_to_spglib_tuple(cs.primitive_structure)
    symmetry_dataset = spg.get_symmetry_dataset(prim_as_tuple,
                                                symprec=cs.symprec)
    # Fail here with a clear message rather than on a later attribute access.
    if symmetry_dataset is None:
        raise ValueError('spglib could not determine the symmetry of the '
                         'primitive cell (symprec={})'.format(cs.symprec))
    cs._symmetry_dataset = symmetry_dataset

    logger.info('Crystal symmetry:')
    logger.info('  Spacegroup: {}'.format(cs.spacegroup))
    logger.info('  Unique site: {}'.format(len(set(cs.wyckoff_sites))))
    logger.info('  Symmetry operations: {}'.format(len(cs.rotation_matrices)))
    logger.info('  symprec: {:.2e}'.format(cs.symprec))
    logger.info('')


def _create_atom_list(cs):
    """Build ``cs._atom_list``: the center-cell atoms plus all nearby images.

    Atoms are identified as ``Atom(index, offset)``. The center atoms (offset
    ``[0, 0, 0]``) are appended first, so the first ``len(cs._prim)`` entries
    of the resulting BiMap are the center cell. Every periodic image within
    ``cs.cutoffs.max_cutoff - cs.symprec`` of a center atom is then added.

    Raises
    ------
    ValueError
        If some atom lies within ``cs.symprec`` of the maximum cutoff, so that
        the neighbor set would depend on numerical noise. (ValueError is a
        subclass of the generic Exception raised previously.)
    """
    tol = cs.symprec
    prim = cs._prim
    n_prim = len(prim)
    max_cutoff = cs.cutoffs.max_cutoff
    atom_list = BiMap()

    # Populating the atom list with the center atoms
    for i in range(n_prim):
        atom_list.append(Atom(i, [0, 0, 0]))

    logger.info('Cutoffs:')
    logger.info('  Maximum cutoff: {}'.format(max_cutoff))

    # ASE's NeighborList counts two atoms as neighbors when the spheres given
    # by their per-atom radii overlap, so a radius of (max_cutoff - tol) / 2
    # per atom selects pair distances below max_cutoff - tol. The pair cutoff
    # should be larger than or equal to the others.
    nl = NeighborList(cutoffs=[(max_cutoff - tol) / 2] * n_prim, skin=0,
                      self_interaction=True, bothways=True)
    nl.update(prim)
    for i in range(n_prim):
        for index, offset in zip(*nl.get_neighbors(i)):
            atom = Atom(index, offset)
            if atom not in atom_list:
                atom_list.append(atom)

    # Repeat with a slightly larger radius: any atom found only now lies
    # within about tol of the cutoff sphere.
    nl = NeighborList(cutoffs=[(max_cutoff + tol) / 2] * n_prim, skin=0,
                      self_interaction=True, bothways=True)
    nl.update(prim)
    distance_from_cutoff = tol  # stays equal to tol if no new atom is found
    for i in range(n_prim):
        for index, offset in zip(*nl.get_neighbors(i)):
            atom = Atom(index, offset)
            # ... and check that no new atom is found
            if atom not in atom_list:
                pos = atom.pos(prim.basis, prim.cell)
                distance = min(np.linalg.norm(pos - center.position)
                               for center in prim) - max_cutoff
                distance_from_cutoff = min(distance, distance_from_cutoff)

    if distance_from_cutoff != tol:
        raise ValueError('Maximum cutoff close to neighbor shell, change cutoff')

    msg = '  Found {} center atom{} with {} images totaling {} atoms'.format(
        n_prim, 's' if n_prim > 1 else '', len(atom_list) - n_prim,
        len(atom_list))
    logger.info(msg)
    logger.info('')

    cs._atom_list = atom_list


# TODO: add atoms property to cs
# TODO: Only inputs are prim, atom_list and cutoffs
def _create_cluster_list(cs):
    """Generate all clusters from ``cs._atom_list`` and store them in ``cs._cluster_list``.

    The atom list is turned into a non-periodic ``Atoms`` object, which is
    passed to ``cs._cluster_filter.setup`` and then to ``get_clusters``
    together with ``cs.cutoffs`` and the number of center atoms. Cluster
    counts per size are logged.
    """
    prim = cs._prim
    # Convert the atom list from site/offset to scaled positions
    spos = [a.spos(prim.basis) for a in cs._atom_list]
    numbers = [prim.numbers[a.site] for a in cs._atom_list]

    # Make an atoms object out of the scaled positions
    atoms = Atoms(cell=prim.cell, scaled_positions=spos, numbers=numbers,
                  pbc=False)

    cs._cluster_filter.setup(atoms)
    cs._cluster_list = get_clusters(atoms, cs.cutoffs, n_prim := len(prim)) \
        if False else get_clusters(atoms, cs.cutoffs, len(prim))

    logger.info('Clusters:')
    counter = Counter(len(c) for c in cs._cluster_list)
    logger.info('  Clusters: {}'.format(dict(counter)))
    logger.info('  Total number of clusters: {}\n'.format(sum(counter.values())))


def _partition_orbits(orbits, keep):
    """Split ``orbits`` into ``(kept, dropped)`` lists, preserving order.

    ``keep(orbit)`` is evaluated once per orbit, in order; truthy keeps it.
    """
    kept, dropped = [], []
    for orbit in orbits:
        (kept if keep(orbit) else dropped).append(orbit)
    return kept, dropped


def _create_orbits(cs):  # TODO: Check scaled/cart
    """Group clusters into orbits and drop those rejected by the cluster filter.

    Orbits come from ``get_orbits``. An orbit whose prototype cluster fails
    ``cs._cluster_filter`` is moved to a fresh ``cs._dropped_orbits`` list;
    the rest remain in ``cs._orbits``. Orbit counts per order are logged.
    """
    cs._orbits = get_orbits(cs._cluster_list,
                            cs._atom_list,
                            cs.rotation_matrices,
                            cs.translation_vectors,
                            cs.permutations,
                            cs._prim,
                            cs.symprec)

    kept, dropped = _partition_orbits(
        cs.orbits,
        lambda orbit: cs._cluster_filter(cs._cluster_list[orbit.prototype_index]))
    cs._dropped_orbits = dropped
    cs._orbits = kept

    counter = Counter(orb.order for orb in cs._orbits)
    logger.info('Orbits:')
    logger.info('  Orbits: {}'.format(dict(counter)))
    logger.info('  Total number of orbits: {}\n'.format(sum(counter.values())))


def _drop_orbits(cs):
    """Move orbits without eigentensors to ``cs._dropped_orbits`` and log statistics.

    Dropped orbits are appended to the existing ``cs._dropped_orbits`` list.
    The number of eigentensors per order, the total parameter count, the
    prototype clusters of all discarded orbits and a warning for every order
    with no remaining orbit are logged.
    """
    kept, dropped = _partition_orbits(cs.orbits,
                                      lambda orbit: orbit.eigentensors)
    cs._dropped_orbits.extend(dropped)
    cs._orbits = kept

    logger.info('Eigentensors:')
    n_ets = {order: sum(len(orb.eigentensors)
                        for orb in cs.orbits if orb.order == order)
             for order in cs.cutoffs.orders}
    logger.info('  Eigentensors: {}'.format(n_ets))
    logger.info('  Total number of parameters: {}'.format(sum(n_ets.values())))
    if cs._dropped_orbits:
        logger.info('  Discarded orbits:')
        for orb in cs._dropped_orbits:
            logger.info('    {}'.format(cs.cluster_list[orb.prototype_index]))

    counter = Counter(orb.order for orb in cs._orbits)
    for order in cs.cutoffs.orders:
        if counter[order] == 0:
            logger.warning('  Warning: No orbits exists for order {}'.format(order))
    logger.info('')


def _populate_ofs_with_ets(cs):
    """Give every orientation family its rotated copy of the orbit's eigentensors.

    For each orientation family, the inverse of the rotation matrix selected
    by ``of.symmetry_index`` is converted to an integer matrix, expanded with
    ``rotation_tensor_as_matrix`` for the orbit's order (cached per
    ``(symmetry_index, order)``), and applied to each orbit eigentensor. The
    results must be int64.
    """
    R_inv_lookup = {}
    for orbit in cs.orbits:
        for of in orbit.orientation_families:
            key = (of.symmetry_index, orbit.order)
            R_inv = R_inv_lookup.get(key)
            if R_inv is None:
                R_inv_float = np.linalg.inv(cs.rotation_matrices[of.symmetry_index])
                # Round before casting: astype() truncates toward zero, so a
                # floating-point result like 0.9999999999999998 would become 0.
                R_inv_int = np.rint(R_inv_float).astype(np.int64)
                assert np.allclose(R_inv_int, R_inv_float), (R_inv_int, R_inv_float)
                R_inv = rotation_tensor_as_matrix(R_inv_int, orbit.order)
                R_inv_lookup[key] = R_inv
            rotated = [rotate_tensor_precalc(et, R_inv) for et in orbit.eigentensors]
            assert all(et.dtype == np.int64 for et in rotated)
            of.eigentensors = rotated


def _rotate_eigentensors(cs):
    """Transform all eigentensors from scaled to Cartesian coordinates.

    The transform is built from ``inv(cs._prim.cell.T)`` via
    ``rotation_tensor_as_matrix`` (cached per order). It is applied to both
    the orbit eigentensors and the orientation-family eigentensors.
    """
    V_invT = np.linalg.inv(cs._prim.cell.T)
    lookup = {}
    for orb in cs.orbits:
        V_invT_tensormatrix = lookup.get(orb.order)
        if V_invT_tensormatrix is None:
            V_invT_tensormatrix = rotation_tensor_as_matrix(V_invT, orb.order)
            lookup[orb.order] = V_invT_tensormatrix
        orb.eigentensors = [rotate_tensor_precalc(et, V_invT_tensormatrix)
                            for et in orb.eigentensors]
        for of in orb.orientation_families:
            of.eigentensors = [rotate_tensor_precalc(et, V_invT_tensormatrix)
                               for et in of.eigentensors]


def _normalize_constraint_vectors(cs):
    """Scale each column of the constraint matrix ``cs._cvs`` to unit norm, in place.

    ``cs._cvs`` is accessed through ``.col``, ``.data`` and ``.shape``, i.e. as
    a COO-format sparse matrix. Its ``data`` is assumed to be floating point
    (TODO confirm). Columns whose stored entries are all zero stay zero instead
    of becoming NaN through a 0/0 division.
    """
    M = cs._cvs
    # Per-column sum of squares; duplicate (row, col) entries each contribute,
    # just like an explicit per-entry loop.
    norms = np.sqrt(np.bincount(M.col, weights=M.data ** 2,
                                minlength=M.shape[1]))
    norms[norms == 0] = 1.0
    M.data /= norms[M.col]


def _rescale_eigentensors(cs):
    """Divide order-n eigentensors by ``cs.length_scale**n``, compensating in ``cs._cvs``.

    Every orbit and orientation-family eigentensor is converted to float64 and
    divided by ``length_scale**order``. The eigentensor lists are updated in
    place. Row ``r`` of ``cs._cvs`` corresponds to the ``r``-th orbit
    eigentensor (counted orbit by orbit), and its entries are multiplied by
    the same factor.
    """
    orders = []
    for orbit in cs.orbits:
        norm = cs.length_scale ** orbit.order
        orbit.eigentensors[:] = [et.astype(np.float64) / norm
                                 for et in orbit.eigentensors]
        for of in orbit.orientation_families:
            of.eigentensors[:] = [et.astype(np.float64) / norm
                                  for et in of.eigentensors]
        orders.extend([orbit.order] * len(orbit.eigentensors))

    M = cs._cvs
    row_orders = np.asarray(orders, dtype=np.int64)[M.row]
    M.data *= cs.length_scale ** row_orders