1use super::aev_params::{species_index, AevParams, N_SPECIES};
11use super::cutoff::cosine_cutoff;
12use super::neighbor::NeighborPair;
13
14pub fn compute_aevs(
18 elements: &[u8],
19 positions: &[[f64; 3]],
20 neighbors: &[NeighborPair],
21 params: &AevParams,
22) -> Vec<Vec<f64>> {
23 let n = elements.len();
24 let aev_len = params.total_aev_length();
25
26 let mut atom_neighbors: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
28 for np in neighbors {
29 let d = np.dist_sq.sqrt();
30 atom_neighbors[np.i].push((np.j, d));
31 atom_neighbors[np.j].push((np.i, d));
32 }
33
34 #[cfg(feature = "parallel")]
35 {
36 use rayon::prelude::*;
37 (0..n)
38 .into_par_iter()
39 .map(|i| {
40 let mut aev = vec![0.0f64; aev_len];
41 if species_index(elements[i]).is_some() {
42 compute_radial_aev(
43 i,
44 elements,
45 positions,
46 &atom_neighbors[i],
47 params,
48 &mut aev,
49 );
50 compute_angular_aev(
51 i,
52 elements,
53 positions,
54 &atom_neighbors[i],
55 params,
56 &mut aev,
57 );
58 }
59 aev
60 })
61 .collect()
62 }
63 #[cfg(not(feature = "parallel"))]
64 {
65 let mut aevs = vec![vec![0.0f64; aev_len]; n];
66 for i in 0..n {
67 if species_index(elements[i]).is_none() {
68 continue;
69 }
70 compute_radial_aev(
71 i,
72 elements,
73 positions,
74 &atom_neighbors[i],
75 params,
76 &mut aevs[i],
77 );
78 compute_angular_aev(
79 i,
80 elements,
81 positions,
82 &atom_neighbors[i],
83 params,
84 &mut aevs[i],
85 );
86 }
87 aevs
88 }
89}
90
91fn compute_radial_aev(
92 _i: usize,
93 elements: &[u8],
94 _positions: &[[f64; 3]],
95 neighbors_i: &[(usize, f64)],
96 params: &AevParams,
97 aev: &mut [f64],
98) {
99 let rad_len = params.radial_length();
100
101 for &(j, rij) in neighbors_i {
102 if rij >= params.radial_cutoff {
103 continue;
104 }
105 let sj = match species_index(elements[j]) {
106 Some(s) => s,
107 None => continue,
108 };
109 let fc = cosine_cutoff(rij, params.radial_cutoff);
110 let offset = sj * rad_len;
111
112 let mut k = 0;
113 for eta in ¶ms.radial_eta {
114 for rs in ¶ms.radial_rs {
115 let dr = rij - rs;
116 aev[offset + k] += (-eta * dr * dr).exp() * fc;
117 k += 1;
118 }
119 }
120 }
121}
122
123fn compute_angular_aev(
124 i: usize,
125 elements: &[u8],
126 positions: &[[f64; 3]],
127 neighbors_i: &[(usize, f64)],
128 params: &AevParams,
129 aev: &mut [f64],
130) {
131 let rad_total = N_SPECIES * params.radial_length();
132 let ang_len = params.angular_length();
133
134 let ang_neighbors: Vec<(usize, f64)> = neighbors_i
136 .iter()
137 .filter(|&&(_, d)| d < params.angular_cutoff)
138 .copied()
139 .collect();
140
141 for a in 0..ang_neighbors.len() {
142 let (j, rij) = ang_neighbors[a];
143 let sj = match species_index(elements[j]) {
144 Some(s) => s,
145 None => continue,
146 };
147 let fc_ij = cosine_cutoff(rij, params.angular_cutoff);
148
149 for b in (a + 1)..ang_neighbors.len() {
150 let (k, rik) = ang_neighbors[b];
151 let sk = match species_index(elements[k]) {
152 Some(s) => s,
153 None => continue,
154 };
155 let fc_ik = cosine_cutoff(rik, params.angular_cutoff);
156
157 let theta = compute_angle(positions, i, j, k);
158 let (s_lo, s_hi) = if sj <= sk { (sj, sk) } else { (sk, sj) };
159 let pair_idx = s_lo * (2 * N_SPECIES - s_lo - 1) / 2 + (s_hi - s_lo);
160 let offset = rad_total + pair_idx * ang_len;
161
162 let r_avg = (rij + rik) / 2.0;
163 let mut m = 0;
164 for eta in ¶ms.angular_eta {
165 for rs in ¶ms.angular_rs {
166 for zeta in ¶ms.angular_zeta {
167 for theta_s in ¶ms.angular_theta_s {
168 let cos_term = 1.0 + (theta - theta_s).cos();
169 let angular = 2.0f64.powf(1.0 - zeta) * cos_term.powf(*zeta);
170 let radial = (-eta * (r_avg - rs).powi(2)).exp();
171 aev[offset + m] += angular * radial * fc_ij * fc_ik;
172 m += 1;
173 }
174 }
175 }
176 }
177 }
178 }
179}
180
181fn compute_angle(positions: &[[f64; 3]], i: usize, j: usize, k: usize) -> f64 {
183 let vij = [
184 positions[j][0] - positions[i][0],
185 positions[j][1] - positions[i][1],
186 positions[j][2] - positions[i][2],
187 ];
188 let vik = [
189 positions[k][0] - positions[i][0],
190 positions[k][1] - positions[i][1],
191 positions[k][2] - positions[i][2],
192 ];
193 let dot = vij[0] * vik[0] + vij[1] * vik[1] + vij[2] * vik[2];
194 let nij = (vij[0] * vij[0] + vij[1] * vij[1] + vij[2] * vij[2]).sqrt();
195 let nik = (vik[0] * vik[0] + vik[1] * vik[1] + vik[2] * vik[2]).sqrt();
196 let cos_theta = (dot / (nij * nik)).clamp(-1.0, 1.0);
197 cos_theta.acos()
198}
199
200#[cfg(test)]
201mod tests {
202 use super::super::aev_params::default_ani2x_params;
203 use super::super::neighbor::CellList;
204 use super::*;
205
206 #[test]
207 fn test_aev_water() {
208 let elements = [8u8, 1, 1];
209 let positions = [
210 [0.0, 0.0, 0.117],
211 [0.0, 0.757, -0.469],
212 [0.0, -0.757, -0.469],
213 ];
214 let params = default_ani2x_params();
215 let cl = CellList::new(&positions, params.radial_cutoff);
216 let neighbors = cl.find_neighbors(&positions);
217 let aevs = compute_aevs(&elements, &positions, &neighbors, ¶ms);
218
219 assert_eq!(aevs.len(), 3);
220 let diff: f64 = aevs[1]
222 .iter()
223 .zip(aevs[2].iter())
224 .map(|(a, b)| (a - b).abs())
225 .sum();
226 assert!(diff < 1e-10, "H atoms in water should have symmetric AEVs");
227 }
228
229 #[test]
230 fn test_aev_nonzero() {
231 let elements = [6u8, 1, 1, 1, 1];
232 let positions = [
233 [0.0, 0.0, 0.0],
234 [0.63, 0.63, 0.63],
235 [-0.63, -0.63, 0.63],
236 [-0.63, 0.63, -0.63],
237 [0.63, -0.63, -0.63],
238 ];
239 let params = default_ani2x_params();
240 let cl = CellList::new(&positions, params.radial_cutoff);
241 let neighbors = cl.find_neighbors(&positions);
242 let aevs = compute_aevs(&elements, &positions, &neighbors, ¶ms);
243
244 let sum: f64 = aevs[0].iter().map(|v| v.abs()).sum();
246 assert!(sum > 0.0, "Carbon AEV should be nonzero");
247 }
248}