Coverage for hiphive/core/rotational_constraints.py: 97%
100 statements
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
« prev ^ index » next coverage.py v7.10.1, created at 2025-08-01 17:04 +0000
"""
Functionality for enforcing rotational sum rules.
"""
from typing import Any, List, Optional, Union

import numpy as np
from scipy.sparse import coo_matrix, lil_matrix
from sklearn.linear_model import Ridge

from ..cluster_space import ClusterSpace
def enforce_rotational_sum_rules(
    cs: ClusterSpace,
    parameters: np.ndarray,
    sum_rules: Optional[Union[str, List[str]]] = None,
    alpha: float = 1e-6,
    **ridge_kwargs: Any,
) -> np.ndarray:
    """ Enforces rotational sum rules by projecting parameters.

    Note
    ----
    The interface to this function might change in future releases.

    Parameters
    ----------
    cs
        The underlying cluster space.
    parameters
        Parameters to be constrained. The input is not modified; a corrected
        float copy is returned.
    sum_rules
        Type of sum rules to enforce; possible values: `'Huang'`,
        `'Born-Huang'`. A single rule may be given as a plain string.
        If None (default), all sum rules are enforced.
    alpha
        Hyperparameter of the ridge regression algorithm; keyword argument
        passed to the optimizer; larger values specify stronger regularization,
        i.e., less correction but higher stability [default: 1e-6].
    ridge_kwargs
        kwargs to be passed to sklearn Ridge. They take precedence over the
        defaults used here (``fit_intercept=False``, ``solver='sparse_cg'``).

    Returns
    -------
    Constrained parameters.

    Examples
    --------
    The rotational sum rules can be enforced to the parameters before
    constructing a force constant potential as illustrated by the following
    snippet::

        cs = ClusterSpace(reference_structure, cutoffs)
        sc = StructureContainer(cs)
        # add structures to structure container
        opt = Optimizer(sc.get_fit_data())
        opt.train()
        new_params = enforce_rotational_sum_rules(cs, opt.parameters,
            sum_rules=['Huang', 'Born-Huang'])
        fcp = ForceConstantPotential(cs, new_params)
    """
    all_sum_rules = ['Huang', 'Born-Huang']

    # setup: work on a float copy so the caller's array is never modified and
    # list/integer input does not break the in-place subtraction below
    parameters = np.array(parameters, dtype=float)
    if sum_rules is None:
        sum_rules = all_sum_rules
    elif isinstance(sum_rules, str):
        # a bare string would otherwise be iterated character by character
        sum_rules = [sum_rules]

    # get constraint-matrix
    M = get_rotational_constraint_matrix(cs, sum_rules)

    # before fit
    d = M.dot(parameters)
    delta = np.linalg.norm(d)
    print('Rotational sum-rules before, ||Ax|| = {:20.15f}'.format(delta))

    # fitting: Ridge finds a (regularized) correction c with M c ~= d, so
    # subtracting c drives M @ parameters towards zero.
    # User-supplied ridge_kwargs override the defaults instead of colliding
    # with them (which previously raised a duplicate-keyword TypeError).
    ridge_options = {'fit_intercept': False, 'solver': 'sparse_cg'}
    ridge_options.update(ridge_kwargs)
    ridge = Ridge(alpha=alpha, **ridge_options)
    ridge.fit(M, d)
    parameters -= ridge.coef_

    # after fit
    d = M.dot(parameters)
    delta = np.linalg.norm(d)
    print('Rotational sum-rules after, ||Ax|| = {:20.15f}'.format(delta))

    return parameters
def get_rotational_constraint_matrix(cs, sum_rules=None):
    """Assemble the dense rotational sum-rule constraint matrix for ``cs``.

    Parameters
    ----------
    cs
        Cluster space; its orbits, ``atom_list``, ``_prim`` and ``_cvs`` are
        read.
    sum_rules
        Sum rules to include, any of ``'Huang'``, ``'Born-Huang'``. A single
        rule may be given as a plain string. Defaults to all sum rules.

    Returns
    -------
    numpy.ndarray
        The constraint matrices of the requested rules, in the given order,
        each right-multiplied by ``cs._cvs`` and stacked vertically.

    Raises
    ------
    ValueError
        If ``sum_rules`` is empty or contains an unknown rule.
    """
    all_sum_rules = ['Huang', 'Born-Huang']

    if sum_rules is None:
        sum_rules = all_sum_rules
    elif isinstance(sum_rules, str):
        # a bare string would otherwise be iterated character by character
        sum_rules = [sum_rules]

    # validate with real exceptions; an assert is stripped under `python -O`
    if len(sum_rules) == 0:
        raise ValueError('At least one sum rule must be given, select from {}'.format(
            all_sum_rules))
    for s in sum_rules:
        if s not in all_sum_rules:
            raise ValueError('Sum rule {} not allowed, select from {}'.format(s, all_sum_rules))

    # make orbit-parameter index map
    params = _get_orbit_parameter_map(cs)
    lookup = _get_fc_lookup(cs)

    # build one constraint matrix per requested rule, in the requested order
    builders = {
        'Huang': _create_Huang_constraint,
        'Born-Huang': _create_Born_Huang_constraint,
    }
    args = (lookup, params, cs.atom_list, cs._prim)
    # NOTE(review): cs._cvs presumably maps the fitted parameters onto the
    # eigentensor parameters indexed by `params` — confirm against ClusterSpace
    cvs_trans = cs._cvs
    Ms = [builders[rule](*args).dot(cvs_trans).toarray() for rule in sum_rules]

    return np.vstack(Ms)
def _get_orbit_parameter_map(cs):
    """Assign consecutive parameter indices to the eigentensors of each orbit.

    Returns a list with one entry per orbit of ``cs.orbits``; entry ``k`` is
    the list of indices belonging to orbit ``k``, continuing where the
    previous orbit stopped (an orbit without eigentensors gets ``[]``).
    """
    sizes = [len(orbit.eigentensors) for orbit in cs.orbits]
    index_map = []
    start = 0
    for size in sizes:
        stop = start + size
        index_map.append(list(range(start, stop)))
        start = stop
    return index_map
def _get_fc_lookup(cs):
    """Index every cluster of ``cs`` by its atom-index tuple.

    Returns a dict mapping ``tuple(cluster)`` to ``(eigentensors,
    orbit_index)``. ``eigentensors`` holds the eigentensors of the cluster's
    orientation family, each transposed by the permutation paired with that
    cluster. If a cluster occurs more than once, the last occurrence wins.
    """
    def _entries():
        # yield (key, value) pairs in orbit / family / cluster order
        for orbit_idx, orbit in enumerate(cs.orbits):
            for family in orbit.orientation_families:
                members = zip(family.cluster_indices, family.permutation_indices)
                for c_idx, p_idx in members:
                    atoms = cs._cluster_list[c_idx]
                    axes = cs._permutations[p_idx]
                    permuted = [tensor.transpose(axes) for tensor in family.eigentensors]
                    yield tuple(atoms), (permuted, orbit_idx)

    return dict(_entries())
def _create_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """Build the Huang sum-rule constraint matrix.

    For each atom index ``i < len(prim)`` and each atom ``j`` of
    ``atom_list`` whose sorted pair ``(i, j)`` is a key of ``lookup``, every
    pair eigentensor ``et`` (re-ordered to the ``(i, j)`` index order)
    contributes ``C_abcd = et_ab * R_c * R_d - et_cd * R_a * R_b`` with
    ``R = pos(i) - pos(j)``. The 3**4 = 81 flattened components are
    accumulated into the column of the eigentensor's parameter.

    Parameters
    ----------
    lookup
        Maps sorted atom-index tuples to ``(eigentensors, orbit_index)``.
    parameter_map
        Per-orbit lists of parameter indices.
    atom_list
        Atoms providing ``pos(basis, cell)``; entries ``0 .. len(prim) - 1``
        are used as the ``i`` atoms (presumably the primitive-cell atoms —
        confirm against the caller).
    prim
        Object providing ``len()``, ``basis`` and ``cell``.

    Returns
    -------
    scipy.sparse.coo_matrix
        Matrix of shape ``(81, n_parameters)``.
    """
    # Size by the largest parameter index: parameter_map[-1][-1] raised
    # IndexError whenever the last orbit had no parameters.
    n_params = max((index for indices in parameter_map for index in indices), default=-1) + 1
    m = np.zeros((n_params, 3**4))

    for i in range(len(prim)):
        pi = atom_list[i].pos(prim.basis, prim.cell)
        for j in range(len(atom_list)):
            ets, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if ets is None:
                continue
            # lookup keys are sorted; re-order tensor axes to the (i, j) order
            inv_perm = np.argsort(np.argsort((i, j)))
            Rij = pi - atom_list[j].pos(prim.basis, prim.cell)
            for et, et_index in zip(ets, parameter_map[orbit_index]):
                et = et.transpose(inv_perm)
                Cij = np.einsum(et, [0, 1], Rij, [2], Rij, [3])
                Cij = Cij - Cij.transpose([2, 3, 0, 1])
                m[et_index] += Cij.flat

    return coo_matrix(m.transpose())
def _create_Born_Huang_constraint(lookup, parameter_map, atom_list, prim):
    """Build the Born-Huang sum-rule constraint matrix.

    For each atom index ``i < len(prim)``, rows ``27*i .. 27*i + 26`` hold the
    sum over atoms ``j`` of ``atom_list`` whose sorted pair ``(i, j)`` is a
    key of ``lookup`` of ``T_abc = et_ab * R_c - et_ac * R_b``, where ``et``
    is a pair eigentensor re-ordered to the ``(i, j)`` index order and
    ``R = pos(j)``. Each contribution goes to the column of the eigentensor's
    parameter.

    Parameters
    ----------
    lookup
        Maps sorted atom-index tuples to ``(eigentensors, orbit_index)``.
    parameter_map
        Per-orbit lists of parameter indices.
    atom_list
        Atoms providing ``pos(basis, cell)``.
    prim
        Object providing ``len()``, ``basis`` and ``cell``.

    Returns
    -------
    scipy.sparse.coo_matrix
        Matrix of shape ``(27 * len(prim), n_parameters)``.
    """
    # Size by the largest parameter index: parameter_map[-1][-1] raised
    # IndexError whenever the last orbit had no parameters.
    n_params = max((index for indices in parameter_map for index in indices), default=-1) + 1
    block_size = 3**3

    # Use scipy list-of-lists sparse matrix for good tradeoff between
    # restructuring, indexing/slicing and memory footprint
    M = lil_matrix((len(prim) * block_size, n_params))
    for i in range(len(prim)):
        # Use smaller numpy arrays for speedy arithmetic
        m = np.zeros((n_params, block_size))
        for j in range(len(atom_list)):
            ets, orbit_index = lookup.get(tuple(sorted((i, j))), (None, None))
            if ets is None:
                continue
            # lookup keys are sorted; re-order tensor axes to the (i, j) order
            inv_perm = np.argsort(np.argsort((i, j)))
            R = atom_list[j].pos(prim.basis, prim.cell)
            for et, et_index in zip(ets, parameter_map[orbit_index]):
                et = et.transpose(inv_perm)
                tmp = np.einsum(et, [0, 1], R, [2])
                tmp = tmp - tmp.transpose([0, 2, 1])
                m[et_index] += tmp.flat
        M[i * block_size:(i + 1) * block_size, :] = m.transpose()

    # Convert lil_matrix to coo_matrix
    return M.tocoo()