Skip to main content

sci_form/dispersion/
dispersion.rs

1//! D4 Dispersion Energy and Gradients — Core
2//!
3//! Two-body BJ-damped dispersion with optional three-body ATM term.
4
5use super::params::{c8_from_c6, d4_coordination_number, dynamic_c6};
6use serde::{Deserialize, Serialize};
7
8/// D4 configuration.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct D4Config {
11    /// s6 scaling factor (usually 1.0).
12    pub s6: f64,
13    /// s8 scaling factor.
14    pub s8: f64,
15    /// BJ damping a1.
16    pub a1: f64,
17    /// BJ damping a2 (Bohr).
18    pub a2: f64,
19    /// Whether to include three-body ATM term.
20    pub three_body: bool,
21    /// s9 scaling for ATM (usually 1.0).
22    pub s9: f64,
23}
24
25impl Default for D4Config {
26    fn default() -> Self {
27        Self {
28            s6: 1.0,
29            s8: 0.95,
30            a1: 0.45,
31            a2: 4.0,
32            three_body: false,
33            s9: 1.0,
34        }
35    }
36}
37
38/// Result of D4 dispersion calculation.
39///
40/// Energy terms are in Hartree; `total_kcal_mol` provides the
41/// convenience conversion. Two-body (`e2_body`) and three-body
42/// (`e3_body`) terms are reported separately for transparency.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct D4Result {
45    /// Two-body dispersion energy (Hartree).
46    pub e2_body: f64,
47    /// Three-body ATM energy (Hartree).
48    pub e3_body: f64,
49    /// Total dispersion energy (Hartree).
50    pub total_energy: f64,
51    /// Total in kcal/mol.
52    pub total_kcal_mol: f64,
53    /// Per-atom coordination numbers.
54    pub coordination_numbers: Vec<f64>,
55}
56
57/// Compute D4 dispersion energy.
58pub fn compute_d4_energy(elements: &[u8], positions: &[[f64; 3]], config: &D4Config) -> D4Result {
59    let n = elements.len();
60    let cn = d4_coordination_number(elements, positions);
61    let ang_to_bohr = 1.0 / 0.529177;
62
63    #[cfg(feature = "parallel")]
64    let e2 = {
65        use rayon::prelude::*;
66        (0..n)
67            .into_par_iter()
68            .map(|i| {
69                ((i + 1)..n)
70                    .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
71                    .sum::<f64>()
72            })
73            .sum::<f64>()
74    };
75
76    #[cfg(not(feature = "parallel"))]
77    let e2 = (0..n)
78        .map(|i| {
79            ((i + 1)..n)
80                .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
81                .sum::<f64>()
82        })
83        .sum::<f64>();
84
85    #[cfg(feature = "parallel")]
86    let e3 = if config.three_body && n >= 3 {
87        use rayon::prelude::*;
88        (0..n)
89            .into_par_iter()
90            .map(|i| {
91                let mut subtotal = 0.0;
92                for j in (i + 1)..n {
93                    for k in (j + 1)..n {
94                        subtotal +=
95                            triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
96                    }
97                }
98                subtotal
99            })
100            .sum::<f64>()
101    } else {
102        0.0
103    };
104
105    #[cfg(not(feature = "parallel"))]
106    let e3 = if config.three_body && n >= 3 {
107        let mut total = 0.0;
108        for i in 0..n {
109            for j in (i + 1)..n {
110                for k in (j + 1)..n {
111                    total += triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
112                }
113            }
114        }
115        total
116    } else {
117        0.0
118    };
119
120    let total = e2 + e3;
121    let hartree_to_kcal = 627.509;
122
123    D4Result {
124        e2_body: e2,
125        e3_body: e3,
126        total_energy: total,
127        total_kcal_mol: total * hartree_to_kcal,
128        coordination_numbers: cn,
129    }
130}
131
132fn pair_energy(
133    i: usize,
134    j: usize,
135    elements: &[u8],
136    positions: &[[f64; 3]],
137    cn: &[f64],
138    config: &D4Config,
139    ang_to_bohr: f64,
140) -> f64 {
141    let dx = (positions[i][0] - positions[j][0]) * ang_to_bohr;
142    let dy = (positions[i][1] - positions[j][1]) * ang_to_bohr;
143    let dz = (positions[i][2] - positions[j][2]) * ang_to_bohr;
144    let r = (dx * dx + dy * dy + dz * dz).sqrt();
145
146    if r < 1e-10 {
147        return 0.0;
148    }
149
150    let c6 = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
151    let c8 = c8_from_c6(c6, elements[i], elements[j]);
152    let r0 = if c6 > 1e-10 { (c8 / c6).sqrt() } else { 5.0 };
153    let r_cut = config.a1 * r0 + config.a2;
154
155    let r6 = r.powi(6);
156    let damp6 = r6 / (r6 + r_cut.powi(6));
157    let term6 = -config.s6 * c6 / r6 * damp6;
158
159    let r8 = r.powi(8);
160    let damp8 = r8 / (r8 + r_cut.powi(8));
161    let term8 = -config.s8 * c8 / r8 * damp8;
162
163    term6 + term8
164}
165
166fn triple_energy(
167    i: usize,
168    j: usize,
169    k: usize,
170    elements: &[u8],
171    positions: &[[f64; 3]],
172    cn: &[f64],
173    config: &D4Config,
174    ang_to_bohr: f64,
175) -> f64 {
176    let r_ab = distance_bohr(positions, i, j, ang_to_bohr);
177    let r_bc = distance_bohr(positions, j, k, ang_to_bohr);
178    let r_ca = distance_bohr(positions, k, i, ang_to_bohr);
179
180    if r_ab < 1e-10 || r_bc < 1e-10 || r_ca < 1e-10 {
181        return 0.0;
182    }
183
184    let c6_ab = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
185    let c6_bc = dynamic_c6(elements[j], elements[k], cn[j], cn[k]);
186    let c6_ca = dynamic_c6(elements[k], elements[i], cn[k], cn[i]);
187
188    // C9 ≈ -sqrt(C6_AB * C6_BC * C6_CA) per Grimme D3 (JCP 132, 154104)
189    let c9 = -(c6_ab * c6_bc * c6_ca).abs().sqrt();
190
191    let cos_a = (r_ab * r_ab + r_ca * r_ca - r_bc * r_bc) / (2.0 * r_ab * r_ca);
192    let cos_b = (r_ab * r_ab + r_bc * r_bc - r_ca * r_ca) / (2.0 * r_ab * r_bc);
193    let cos_c = (r_bc * r_bc + r_ca * r_ca - r_ab * r_ab) / (2.0 * r_bc * r_ca);
194    let angular = 3.0 * cos_a * cos_b * cos_c + 1.0;
195    let r_prod = r_ab * r_bc * r_ca;
196
197    config.s9 * c9 * angular / r_prod.powi(3)
198}
199
200/// Compute numerical D4 gradient.
201pub fn compute_d4_gradient(
202    elements: &[u8],
203    positions: &[[f64; 3]],
204    config: &D4Config,
205) -> Vec<[f64; 3]> {
206    let n = elements.len();
207    let h = 1e-5;
208    let mut gradient = vec![[0.0; 3]; n];
209
210    for i in 0..n {
211        for d in 0..3 {
212            let mut pos_p = positions.to_vec();
213            let mut pos_m = positions.to_vec();
214            pos_p[i][d] += h;
215            pos_m[i][d] -= h;
216
217            let ep = compute_d4_energy(elements, &pos_p, config).total_energy;
218            let em = compute_d4_energy(elements, &pos_m, config).total_energy;
219
220            gradient[i][d] = (ep - em) / (2.0 * h);
221        }
222    }
223
224    gradient
225}
226
/// Euclidean distance between atoms `i` and `j` after scaling by `ang_to_bohr`.
///
/// Each coordinate difference is multiplied by `ang_to_bohr` before squaring.
/// For positions in Ångström, passing the Å→Bohr factor yields Bohr.
fn distance_bohr(positions: &[[f64; 3]], i: usize, j: usize, ang_to_bohr: f64) -> f64 {
    let (a, b) = (&positions[i], &positions[j]);
    a.iter()
        .zip(b.iter())
        .map(|(p, q)| {
            let delta = (p - q) * ang_to_bohr;
            delta * delta
        })
        .sum::<f64>()
        .sqrt()
}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_d4_energy_negative() {
        let elements = [6, 6];
        let pos = [[0.0, 0.0, 0.0], [3.5, 0.0, 0.0]];
        let config = D4Config::default();
        let result = compute_d4_energy(&elements, &pos, &config);
        assert!(
            result.total_energy < 0.0,
            "D4 energy should be negative: {}",
            result.total_energy
        );
    }

    #[test]
    fn test_d4_decays_with_distance() {
        let elements = [6, 6];
        let e_close = compute_d4_energy(
            &elements,
            &[[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]],
            &D4Config::default(),
        )
        .total_energy;
        let e_far = compute_d4_energy(
            &elements,
            &[[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]],
            &D4Config::default(),
        )
        .total_energy;
        assert!(
            e_close.abs() > e_far.abs(),
            "D4 should decay: close={}, far={}",
            e_close,
            e_far
        );
    }

    #[test]
    fn test_d4_three_body() {
        let elements = [6, 6, 6];
        let pos = [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.25, 2.17, 0.0]];
        let r2 = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: false,
                ..Default::default()
            },
        );
        let r3 = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: true,
                ..Default::default()
            },
        );
        assert_eq!(r2.e3_body, 0.0, "3-body term must vanish when disabled");
        assert!(r3.e3_body != 0.0, "3-body term should be non-zero when enabled");
        // Tolerance rather than equality: rayon may combine partial sums in
        // a different order between runs when the `parallel` feature is on.
        assert!(
            (r2.e2_body - r3.e2_body).abs() < 1e-12,
            "2-body term must not depend on the three_body flag"
        );
        assert!(
            (r3.total_energy - r2.total_energy).abs() > 0.0,
            "3-body should differ from 2-body"
        );
    }

    #[test]
    fn test_d4_result_consistency() {
        let elements = [6, 8, 1, 1];
        let pos = [
            [0.0, 0.0, 0.0],
            [1.23, 0.0, 0.0],
            [-0.6, 0.9, 0.0],
            [-0.6, -0.9, 0.0],
        ];
        let r = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: true,
                ..Default::default()
            },
        );
        assert_eq!(r.coordination_numbers.len(), elements.len());
        assert_eq!(r.total_energy, r.e2_body + r.e3_body);
        assert!(
            (r.total_kcal_mol - r.total_energy * 627.509).abs() < 1e-9,
            "kcal/mol conversion mismatch: {} vs {}",
            r.total_kcal_mol,
            r.total_energy * 627.509
        );
    }

    #[test]
    fn test_d4_gradient_finite() {
        let elements = [6, 8, 1, 1];
        let pos = [
            [0.0, 0.0, 0.0],
            [1.23, 0.0, 0.0],
            [-0.6, 0.9, 0.0],
            [-0.6, -0.9, 0.0],
        ];
        let grad = compute_d4_gradient(&elements, &pos, &D4Config::default());
        assert_eq!(grad.len(), elements.len());
        for g in &grad {
            for &d in g {
                assert!(d.is_finite(), "Gradient contains NaN/Inf");
            }
        }
    }
}