Skip to main content

sci_form/ani/
aev.rs

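//! Behler-Parrinello Atomic Environment Vectors (AEVs).
//!
//! Transforms Cartesian coordinates into rotation- and translation-invariant
//! per-atom descriptors using radial and angular symmetry functions. Radial
//! terms are accumulated separately for each neighbor species; angular terms
//! are accumulated separately for each unordered pair of neighbor species.
//!
//! Radial: $G_R = \sum_{j} e^{-\eta(R_{ij}-R_s)^2} f_c(R_{ij})$
//! Angular: $G_A = 2^{1-\zeta} \sum_{j<k} (1+\cos(\theta_{ijk}-\theta_s))^\zeta
//!          \cdot e^{-\eta((R_{ij}+R_{ik})/2 - R_s)^2} f_c(R_{ij}) f_c(R_{ik})$
//!
//! where $\theta_{ijk}$ is the angle at atom $i$ between the bonds to $j$ and $k$,
//! and $f_c$ is the cosine cutoff function.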
9
10use super::aev_params::{species_index, AevParams, N_SPECIES};
11use super::cutoff::cosine_cutoff;
12use super::neighbor::NeighborPair;
13
14/// Compute AEVs for all atoms in the system.
15///
16/// Returns a `Vec` of length `n_atoms`, each entry is an AEV vector.
17pub fn compute_aevs(
18    elements: &[u8],
19    positions: &[[f64; 3]],
20    neighbors: &[NeighborPair],
21    params: &AevParams,
22) -> Vec<Vec<f64>> {
23    let n = elements.len();
24    let aev_len = params.total_aev_length();
25
26    // Build per-atom neighbor lists (both directions)
27    let mut atom_neighbors: Vec<Vec<(usize, f64)>> = vec![Vec::new(); n];
28    for np in neighbors {
29        let d = np.dist_sq.sqrt();
30        atom_neighbors[np.i].push((np.j, d));
31        atom_neighbors[np.j].push((np.i, d));
32    }
33
34    #[cfg(feature = "parallel")]
35    {
36        use rayon::prelude::*;
37        (0..n)
38            .into_par_iter()
39            .map(|i| {
40                let mut aev = vec![0.0f64; aev_len];
41                if species_index(elements[i]).is_some() {
42                    compute_radial_aev(
43                        i,
44                        elements,
45                        positions,
46                        &atom_neighbors[i],
47                        params,
48                        &mut aev,
49                    );
50                    compute_angular_aev(
51                        i,
52                        elements,
53                        positions,
54                        &atom_neighbors[i],
55                        params,
56                        &mut aev,
57                    );
58                }
59                aev
60            })
61            .collect()
62    }
63    #[cfg(not(feature = "parallel"))]
64    {
65        let mut aevs = vec![vec![0.0f64; aev_len]; n];
66        for i in 0..n {
67            if species_index(elements[i]).is_none() {
68                continue;
69            }
70            compute_radial_aev(
71                i,
72                elements,
73                positions,
74                &atom_neighbors[i],
75                params,
76                &mut aevs[i],
77            );
78            compute_angular_aev(
79                i,
80                elements,
81                positions,
82                &atom_neighbors[i],
83                params,
84                &mut aevs[i],
85            );
86        }
87        aevs
88    }
89}
90
91fn compute_radial_aev(
92    _i: usize,
93    elements: &[u8],
94    _positions: &[[f64; 3]],
95    neighbors_i: &[(usize, f64)],
96    params: &AevParams,
97    aev: &mut [f64],
98) {
99    let rad_len = params.radial_length();
100
101    for &(j, rij) in neighbors_i {
102        if rij >= params.radial_cutoff {
103            continue;
104        }
105        let sj = match species_index(elements[j]) {
106            Some(s) => s,
107            None => continue,
108        };
109        let fc = cosine_cutoff(rij, params.radial_cutoff);
110        let offset = sj * rad_len;
111
112        let mut k = 0;
113        for eta in &params.radial_eta {
114            for rs in &params.radial_rs {
115                let dr = rij - rs;
116                aev[offset + k] += (-eta * dr * dr).exp() * fc;
117                k += 1;
118            }
119        }
120    }
121}
122
123fn compute_angular_aev(
124    i: usize,
125    elements: &[u8],
126    positions: &[[f64; 3]],
127    neighbors_i: &[(usize, f64)],
128    params: &AevParams,
129    aev: &mut [f64],
130) {
131    let rad_total = N_SPECIES * params.radial_length();
132    let ang_len = params.angular_length();
133
134    // Filter neighbors within angular cutoff
135    let ang_neighbors: Vec<(usize, f64)> = neighbors_i
136        .iter()
137        .filter(|&&(_, d)| d < params.angular_cutoff)
138        .copied()
139        .collect();
140
141    for a in 0..ang_neighbors.len() {
142        let (j, rij) = ang_neighbors[a];
143        let sj = match species_index(elements[j]) {
144            Some(s) => s,
145            None => continue,
146        };
147        let fc_ij = cosine_cutoff(rij, params.angular_cutoff);
148
149        for b in (a + 1)..ang_neighbors.len() {
150            let (k, rik) = ang_neighbors[b];
151            let sk = match species_index(elements[k]) {
152                Some(s) => s,
153                None => continue,
154            };
155            let fc_ik = cosine_cutoff(rik, params.angular_cutoff);
156
157            let theta = compute_angle(positions, i, j, k);
158            let (s_lo, s_hi) = if sj <= sk { (sj, sk) } else { (sk, sj) };
159            let pair_idx = s_lo * (2 * N_SPECIES - s_lo - 1) / 2 + (s_hi - s_lo);
160            let offset = rad_total + pair_idx * ang_len;
161
162            let r_avg = (rij + rik) / 2.0;
163            let mut m = 0;
164            for eta in &params.angular_eta {
165                for rs in &params.angular_rs {
166                    for zeta in &params.angular_zeta {
167                        for theta_s in &params.angular_theta_s {
168                            let cos_term = 1.0 + (theta - theta_s).cos();
169                            let angular = 2.0f64.powf(1.0 - zeta) * cos_term.powf(*zeta);
170                            let radial = (-eta * (r_avg - rs).powi(2)).exp();
171                            aev[offset + m] += angular * radial * fc_ij * fc_ik;
172                            m += 1;
173                        }
174                    }
175                }
176            }
177        }
178    }
179}
180
/// Angle ∠jik at atom `i` between the bonds i→j and i→k, in radians in `[0, π]`.
///
/// The cosine is clamped to `[-1, 1]` so rounding error cannot push `acos`
/// outside its domain. If `i` coincides with `j` or `k`, the cosine is 0/0 and
/// the result is NaN.
fn compute_angle(positions: &[[f64; 3]], i: usize, j: usize, k: usize) -> f64 {
    let origin = positions[i];
    let bond = |to: usize| -> [f64; 3] {
        let p = positions[to];
        [p[0] - origin[0], p[1] - origin[1], p[2] - origin[2]]
    };
    let dot = |u: &[f64; 3], v: &[f64; 3]| u[0] * v[0] + u[1] * v[1] + u[2] * v[2];

    let (u, v) = (bond(j), bond(k));
    let cos_theta = dot(&u, &v) / (dot(&u, &u).sqrt() * dot(&v, &v).sqrt());
    cos_theta.clamp(-1.0, 1.0).acos()
}
199
#[cfg(test)]
mod tests {
    use super::super::aev_params::default_ani2x_params;
    use super::super::neighbor::CellList;
    use super::*;

    /// Build a neighbor list with the default ANI-2x parameters and compute all AEVs.
    fn aevs_for(elements: &[u8], positions: &[[f64; 3]]) -> Vec<Vec<f64>> {
        let params = default_ani2x_params();
        let cells = CellList::new(positions, params.radial_cutoff);
        let pairs = cells.find_neighbors(positions);
        compute_aevs(elements, positions, &pairs, &params)
    }

    #[test]
    fn test_aev_water() {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let aevs = aevs_for(&elements, &positions);

        assert_eq!(aevs.len(), 3);
        // The two hydrogens are mirror images, so their environments must match.
        let diff: f64 = aevs[1]
            .iter()
            .zip(&aevs[2])
            .map(|(a, b)| (a - b).abs())
            .sum();
        assert!(diff < 1e-10, "H atoms in water should have symmetric AEVs");
    }

    #[test]
    fn test_aev_nonzero() {
        let elements = [6u8, 1, 1, 1, 1];
        let positions = [
            [0.0, 0.0, 0.0],
            [0.63, 0.63, 0.63],
            [-0.63, -0.63, 0.63],
            [-0.63, 0.63, -0.63],
            [0.63, -0.63, -0.63],
        ];
        let aevs = aevs_for(&elements, &positions);

        // The central carbon has four hydrogens in range, so its AEV cannot be empty.
        let magnitude: f64 = aevs[0].iter().map(|v| v.abs()).sum();
        assert!(magnitude > 0.0, "Carbon AEV should be nonzero");
    }
}