Coverage for hiphive/core/rotational_constraints.py: 97%
100 statements
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
« prev ^ index » next coverage.py v7.6.8, created at 2024-11-28 11:20 +0000
1"""
2Functionality for enforcing rotational sum rules
3"""
4import numpy as np
5from typing import List
6from sklearn.linear_model import Ridge
7from scipy.sparse import coo_matrix, lil_matrix
8from ..cluster_space import ClusterSpace
def enforce_rotational_sum_rules(cs: ClusterSpace,
                                 parameters: np.ndarray,
                                 sum_rules: List[str] = None,
                                 alpha: float = 1e-6,
                                 **ridge_kwargs: dict) -> np.ndarray:
    """ Projects parameters such that rotational sum rules are enforced.

    The violation of the sum rules, ``d = M.dot(parameters)``, is computed
    from the rotational constraint matrix ``M``. A ridge regression without
    intercept is then fitted to ``M x = d`` and the resulting coefficients
    are subtracted from a copy of the parameters. The norm of the violation
    is printed before and after the correction.

    Note
    ----
    The interface to this function might change in future releases.

    Parameters
    ----------
    cs
        the underlying cluster space
    parameters
        parameters to be constrained; the input array is not modified
    sum_rules
        type of sum rules to enforce; possible values: 'Huang', 'Born-Huang';
        if ``None`` both sum rules are enforced
    alpha
        hyperparameter to the ridge regression algorithm; keyword argument
        passed to the optimizer; larger values specify stronger regularization,
        i.e., less correction but higher stability [default: 1e-6]
    ridge_kwargs
        kwargs to be passed to sklearn Ridge

    Returns
    -------
    numpy.ndarray
        constrained parameters

    Examples
    --------
    The rotational sum rules can be enforced to the parameters before
    constructing a force constant potential as illustrated by the following
    snippet::

        cs = ClusterSpace(reference_structure, cutoffs)
        sc = StructureContainer(cs)
        # add structures to structure container
        opt = Optimizer(sc.get_fit_data())
        opt.train()
        new_params = enforce_rotational_sum_rules(cs, opt.parameters,
                                                  sum_rules=['Huang', 'Born-Huang'])
        fcp = ForceConstantPotential(cs, new_params)

    """
    # work on a copy so that the caller's array is left untouched
    corrected = parameters.copy()
    rules = ['Huang', 'Born-Huang'] if sum_rules is None else sum_rules

    constraint_matrix = get_rotational_constraint_matrix(cs, rules)

    # violation of the sum rules for the incoming parameters
    violation = constraint_matrix.dot(corrected)
    print('Rotational sum-rules before, ||Ax|| = {:20.15f}'.format(np.linalg.norm(violation)))

    # the fitted coefficients are the correction that is subtracted below
    regressor = Ridge(alpha=alpha, fit_intercept=False, solver='sparse_cg', **ridge_kwargs)
    regressor.fit(constraint_matrix, violation)
    corrected -= regressor.coef_

    # violation of the sum rules for the corrected parameters
    violation = constraint_matrix.dot(corrected)
    print('Rotational sum-rules after, ||Ax|| = {:20.15f}'.format(np.linalg.norm(violation)))

    return corrected
def get_rotational_constraint_matrix(cs, sum_rules=None):
    """ Assembles the dense constraint matrix for the requested sum rules.

    One sparse constraint matrix is built per requested sum rule, each is
    multiplied from the right with ``cs._cvs`` and converted to a dense
    array, and the results are stacked vertically in the order in which
    the sum rules were given.

    Parameters
    ----------
    cs
        the underlying cluster space
    sum_rules
        sum rules to include; possible values: 'Huang', 'Born-Huang';
        if ``None`` both sum rules are included

    Returns
    -------
    numpy.ndarray
        vertically stacked constraint matrices

    Raises
    ------
    ValueError
        if an entry of `sum_rules` is not one of the allowed values
    """
    known_rules = ['Huang', 'Born-Huang']
    requested = known_rules if sum_rules is None else sum_rules

    # validate the request before doing any work
    assert len(requested) > 0
    for name in requested:
        if name not in known_rules:
            raise ValueError('Sum rule {} not allowed, select from {}'.format(name, known_rules))

    # orbit -> parameter indices and cluster -> eigentensors tables
    parameter_map = _get_orbit_parameter_map(cs)
    fc_lookup = _get_fc_lookup(cs)
    builder_args = (fc_lookup, parameter_map, cs.atom_list, cs._prim)

    # one sparse matrix per requested sum rule
    builders = {
        'Huang': _create_Huang_constraint,
        'Born-Huang': _create_Born_Huang_constraint,
    }
    sparse_blocks = [builders[name](*builder_args) for name in requested]

    # transform with cs._cvs and densify before stacking
    transform = cs._cvs
    dense_blocks = [block.dot(transform).toarray() for block in sparse_blocks]

    return np.vstack(dense_blocks)
def _get_orbit_parameter_map(cs):
    """ Returns the parameter indices that belong to each orbit.

    Parameters are numbered consecutively over ``cs.orbits``; an orbit
    with ``k`` eigentensors is assigned the next ``k`` indices. The
    returned list holds one list of indices per orbit, in orbit order.
    """
    parameter_map = []
    offset = 0
    for orbit in cs.orbits:
        stop = offset + len(orbit.eigentensors)
        parameter_map.append(list(range(offset, stop)))
        offset = stop
    return parameter_map
def _get_fc_lookup(cs):
    """ Builds a lookup table from clusters to eigentensors.

    For every cluster referenced by an orientation family the table maps
    the cluster, converted to a tuple, to a pair consisting of the
    family's eigentensors transposed with the permutation belonging to
    that cluster and the index of the orbit the family is part of.
    If the same cluster occurs more than once the last entry wins.
    """
    table = {}
    for orbit_index, orbit in enumerate(cs.orbits):
        for family in orbit.orientation_families:
            index_pairs = zip(family.cluster_indices, family.permutation_indices)
            for cluster_index, perm_index in index_pairs:
                cluster = cs._cluster_list[cluster_index]
                permutation = cs._permutations[perm_index]
                permuted = [tensor.transpose(permutation) for tensor in family.eigentensors]
                table[tuple(cluster)] = (permuted, orbit_index)
    return table
def _create_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """ Builds the constraint matrix for the Huang sum rule.

    For every pair of an atom ``i`` in the primitive cell and an atom ``j``
    in `atom_list` that has an entry in `lookup`, each eigentensor of the
    pair is combined with the separation vector ``Rij`` into the rank-4
    tensor ``et[a, b] * Rij[c] * Rij[d]``. The part of that tensor that is
    antisymmetric under exchanging the index pairs ``(a, b)`` and
    ``(c, d)`` is accumulated for the parameter that belongs to the
    eigentensor.

    Parameters
    ----------
    lookup
        maps a sorted pair of atom indices to a tuple of the eigentensors
        for that pair and the index of its orbit
    parameter_map
        list with the parameter indices of each orbit
    atom_list
        atoms; each entry must provide ``pos(basis, cell)``
    prim
        primitive structure; provides ``basis``, ``cell`` and a length

    Returns
    -------
    scipy.sparse.coo_matrix
        matrix with ``3**4`` rows and one column per parameter
    """
    n_params = parameter_map[-1][-1] + 1
    m = np.zeros((n_params, 3**4))

    for i in range(len(prim)):
        # the position of atom i does not depend on j, so compute it once
        pos_i = atom_list[i].pos(prim.basis, prim.cell)
        for j in range(len(atom_list)):
            ets, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if ets is None:
                continue
            # the lookup is keyed by the sorted pair; this permutation
            # brings the stored tensors back to the (i, j) ordering
            inv_perm = np.argsort(np.argsort((i, j)))
            # the separation vector is shared by all eigentensors of the pair
            Rij = pos_i - atom_list[j].pos(prim.basis, prim.cell)
            for et, et_index in zip(ets, parameter_map[orbit_index]):
                Cij = np.einsum(et.transpose(inv_perm), [0, 1], Rij, [2], Rij, [3])
                # subtract out of place; an in-place update against a
                # transposed view of the same buffer relies on NumPy's
                # overlap handling to give the right answer
                m[et_index] += (Cij - Cij.transpose([2, 3, 0, 1])).flat

    return coo_matrix(m.transpose())
def _create_Born_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """ Builds the constraint matrix for the Born-Huang sum rule.

    For every atom ``i`` in the primitive cell a block of ``3**3`` rows is
    created. Each atom ``j`` in `atom_list` that has an entry in `lookup`
    for the pair contributes, per eigentensor, the rank-3 tensor
    ``et[a, b] * R[c]`` minus the same tensor with its last two indices
    exchanged, where ``R`` is the position of atom ``j``. The contribution
    is accumulated for the parameter that belongs to the eigentensor.

    Parameters
    ----------
    lookup
        maps a sorted pair of atom indices to a tuple of the eigentensors
        for that pair and the index of its orbit
    parameter_map
        list with the parameter indices of each orbit
    atom_list
        atoms; each entry must provide ``pos(basis, cell)``
    prim
        primitive structure; provides ``basis``, ``cell`` and a length

    Returns
    -------
    scipy.sparse.coo_matrix
        matrix with ``len(prim) * 3**3`` rows and one column per parameter
    """
    n_params = parameter_map[-1][-1] + 1
    rows_per_atom = 3**3
    n_prim = len(prim)

    # list-of-lists format allows cheap assignment of row slices
    constraint = lil_matrix((n_prim * rows_per_atom, n_params))

    for i in range(n_prim):
        # accumulate in a small dense array, one row per parameter
        per_atom = np.zeros((n_params, rows_per_atom))
        for j in range(len(atom_list)):
            tensors, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if tensors is None:
                continue
            # the lookup is keyed by the sorted pair; this permutation
            # brings the stored tensors back to the (i, j) ordering
            axes = np.argsort(np.argsort((i, j)))
            position_j = atom_list[j].pos(prim.basis, prim.cell)
            for tensor, row in zip(tensors, parameter_map[orbit_index]):
                term = np.einsum(tensor.transpose(axes), [0, 1], position_j, [2])
                term -= term.transpose([0, 2, 1])
                per_atom[row] += term.flat
        start = i * rows_per_atom
        constraint[start:start + rows_per_atom, :] = per_atom.transpose()

    return constraint.tocoo()