// sci_form/dispersion/dispersion.rs

//! D4 Dispersion Energy and Gradients — Core
//!
//! Two-body BJ-damped dispersion with optional three-body ATM term.
4
5use super::params::{c8_from_c6, d4_coordination_number, dynamic_c6, get_d4_params};
6use serde::{Deserialize, Serialize};
7
8/// D4 configuration.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct D4Config {
11    /// s6 scaling factor (usually 1.0).
12    pub s6: f64,
13    /// s8 scaling factor.
14    pub s8: f64,
15    /// BJ damping a1.
16    pub a1: f64,
17    /// BJ damping a2 (Bohr).
18    pub a2: f64,
19    /// Whether to include three-body ATM term.
20    pub three_body: bool,
21    /// s9 scaling for ATM (usually 1.0).
22    pub s9: f64,
23}
24
25impl Default for D4Config {
26    fn default() -> Self {
27        Self {
28            s6: 1.0,
29            s8: 0.95,
30            a1: 0.45,
31            a2: 4.0,
32            three_body: false,
33            s9: 1.0,
34        }
35    }
36}
37
38/// Result of D4 dispersion calculation.
39///
40/// Energy terms are in Hartree; `total_kcal_mol` provides the
41/// convenience conversion. Two-body (`e2_body`) and three-body
42/// (`e3_body`) terms are reported separately for transparency.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct D4Result {
45    /// Two-body dispersion energy (Hartree).
46    pub e2_body: f64,
47    /// Three-body ATM energy (Hartree).
48    pub e3_body: f64,
49    /// Total dispersion energy (Hartree).
50    pub total_energy: f64,
51    /// Total in kcal/mol.
52    pub total_kcal_mol: f64,
53    /// Per-atom coordination numbers.
54    pub coordination_numbers: Vec<f64>,
55}
56
57/// Compute D4 dispersion energy.
58pub fn compute_d4_energy(elements: &[u8], positions: &[[f64; 3]], config: &D4Config) -> D4Result {
59    let n = elements.len();
60    let cn = d4_coordination_number(elements, positions);
61    let ang_to_bohr = 1.0 / 0.529177;
62
63    #[cfg(feature = "parallel")]
64    let e2 = {
65        use rayon::prelude::*;
66        (0..n)
67            .into_par_iter()
68            .map(|i| {
69                ((i + 1)..n)
70                    .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
71                    .sum::<f64>()
72            })
73            .sum::<f64>()
74    };
75
76    #[cfg(not(feature = "parallel"))]
77    let e2 = (0..n)
78        .map(|i| {
79            ((i + 1)..n)
80                .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
81                .sum::<f64>()
82        })
83        .sum::<f64>();
84
85    #[cfg(feature = "parallel")]
86    let e3 = if config.three_body && n >= 3 {
87        use rayon::prelude::*;
88        (0..n)
89            .into_par_iter()
90            .map(|i| {
91                let mut subtotal = 0.0;
92                for j in (i + 1)..n {
93                    for k in (j + 1)..n {
94                        subtotal +=
95                            triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
96                    }
97                }
98                subtotal
99            })
100            .sum::<f64>()
101    } else {
102        0.0
103    };
104
105    #[cfg(not(feature = "parallel"))]
106    let e3 = if config.three_body && n >= 3 {
107        let mut total = 0.0;
108        for i in 0..n {
109            for j in (i + 1)..n {
110                for k in (j + 1)..n {
111                    total += triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
112                }
113            }
114        }
115        total
116    } else {
117        0.0
118    };
119
120    let total = e2 + e3;
121    let hartree_to_kcal = 627.509;
122
123    D4Result {
124        e2_body: e2,
125        e3_body: e3,
126        total_energy: total,
127        total_kcal_mol: total * hartree_to_kcal,
128        coordination_numbers: cn,
129    }
130}
131
132fn pair_energy(
133    i: usize,
134    j: usize,
135    elements: &[u8],
136    positions: &[[f64; 3]],
137    cn: &[f64],
138    config: &D4Config,
139    ang_to_bohr: f64,
140) -> f64 {
141    let dx = (positions[i][0] - positions[j][0]) * ang_to_bohr;
142    let dy = (positions[i][1] - positions[j][1]) * ang_to_bohr;
143    let dz = (positions[i][2] - positions[j][2]) * ang_to_bohr;
144    let r = (dx * dx + dy * dy + dz * dz).sqrt();
145
146    if r < 1e-10 {
147        return 0.0;
148    }
149
150    let c6 = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
151    let c8 = c8_from_c6(c6, elements[i], elements[j]);
152    let r0 = if c6 > 1e-10 { (c8 / c6).sqrt() } else { 5.0 };
153    let r_cut = config.a1 * r0 + config.a2;
154
155    let r6 = r.powi(6);
156    let damp6 = r6 / (r6 + r_cut.powi(6));
157    let term6 = -config.s6 * c6 / r6 * damp6;
158
159    let r8 = r.powi(8);
160    let damp8 = r8 / (r8 + r_cut.powi(8));
161    let term8 = -config.s8 * c8 / r8 * damp8;
162
163    term6 + term8
164}
165
166fn triple_energy(
167    i: usize,
168    j: usize,
169    k: usize,
170    elements: &[u8],
171    positions: &[[f64; 3]],
172    cn: &[f64],
173    config: &D4Config,
174    ang_to_bohr: f64,
175) -> f64 {
176    let r_ab = distance_bohr(positions, i, j, ang_to_bohr);
177    let r_bc = distance_bohr(positions, j, k, ang_to_bohr);
178    let r_ca = distance_bohr(positions, k, i, ang_to_bohr);
179
180    if r_ab < 1e-10 || r_bc < 1e-10 || r_ca < 1e-10 {
181        return 0.0;
182    }
183
184    let c6_ab = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
185    let c6_bc = dynamic_c6(elements[j], elements[k], cn[j], cn[k]);
186    let c6_ca = dynamic_c6(elements[k], elements[i], cn[k], cn[i]);
187
188    // C9 ≈ -sqrt(C6_AB * C6_BC * C6_CA) per Grimme D3 (JCP 132, 154104)
189    let c9 = -(c6_ab * c6_bc * c6_ca).abs().sqrt();
190
191    let cos_a = (r_ab * r_ab + r_ca * r_ca - r_bc * r_bc) / (2.0 * r_ab * r_ca);
192    let cos_b = (r_ab * r_ab + r_bc * r_bc - r_ca * r_ca) / (2.0 * r_ab * r_bc);
193    let cos_c = (r_bc * r_bc + r_ca * r_ca - r_ab * r_ab) / (2.0 * r_bc * r_ca);
194    let angular = 3.0 * cos_a * cos_b * cos_c + 1.0;
195    let r_prod = r_ab * r_bc * r_ca;
196
197    // ATM BJ-like damping per Grimme D3 (JCP 132, 154104, Eq. 6):
198    // f_damp = 1 / (1 + 6 * (R0_ABC / r_mean)^alpha)
199    // where R0_ABC = (4/3)*(R_cov_A + R_cov_B)^(1/3) * ... geometric mean of covalent radii
200    // and alpha = 14 (steep damping exponent)
201    let r_cov_i = get_d4_params(elements[i]).r_cov;
202    let r_cov_j = get_d4_params(elements[j]).r_cov;
203    let r_cov_k = get_d4_params(elements[k]).r_cov;
204    let r0_ab = (4.0 / 3.0) * (r_cov_i + r_cov_j);
205    let r0_bc = (4.0 / 3.0) * (r_cov_j + r_cov_k);
206    let r0_ca = (4.0 / 3.0) * (r_cov_k + r_cov_i);
207    let r0_prod = r0_ab * r0_bc * r0_ca;
208    let r9 = r_prod.powi(3);
209    let r0_9 = r0_prod.powi(3);
210    let fdamp = 1.0 / (1.0 + 6.0 * (r0_9 / r9));
211
212    config.s9 * c9 * angular / r9 * fdamp
213}
214
215/// Compute numerical D4 gradient.
216pub fn compute_d4_gradient(
217    elements: &[u8],
218    positions: &[[f64; 3]],
219    config: &D4Config,
220) -> Vec<[f64; 3]> {
221    let n = elements.len();
222    let h = 1e-5;
223    let mut gradient = vec![[0.0; 3]; n];
224
225    for i in 0..n {
226        for d in 0..3 {
227            let mut pos_p = positions.to_vec();
228            let mut pos_m = positions.to_vec();
229            pos_p[i][d] += h;
230            pos_m[i][d] -= h;
231
232            let ep = compute_d4_energy(elements, &pos_p, config).total_energy;
233            let em = compute_d4_energy(elements, &pos_m, config).total_energy;
234
235            gradient[i][d] = (ep - em) / (2.0 * h);
236        }
237    }
238
239    gradient
240}
241
/// Euclidean distance between atoms `i` and `j`, with every coordinate
/// difference scaled by `ang_to_bohr` before squaring.
///
/// Panics if `i` or `j` is out of range for `positions`.
fn distance_bohr(positions: &[[f64; 3]], i: usize, j: usize, ang_to_bohr: f64) -> f64 {
    let (from, to) = (&positions[i], &positions[j]);
    from.iter()
        .zip(to.iter())
        .map(|(a, b)| {
            let delta = (a - b) * ang_to_bohr;
            delta * delta
        })
        // Accumulate x², y², z² left to right.
        .fold(0.0_f64, |acc, squared| acc + squared)
        .sqrt()
}
248
#[cfg(test)]
mod tests {
    use super::*;

    /// Total D4 energy (default parameters) of two carbon atoms placed
    /// `separation` apart along the x axis.
    fn carbon_dimer_energy(separation: f64) -> f64 {
        let elements = [6, 6];
        let pos = [[0.0, 0.0, 0.0], [separation, 0.0, 0.0]];
        compute_d4_energy(&elements, &pos, &D4Config::default()).total_energy
    }

    /// Dispersion between two carbon atoms must lower the energy.
    #[test]
    fn test_d4_energy_negative() {
        let energy = carbon_dimer_energy(3.5);
        assert!(energy < 0.0, "D4 energy should be negative: {}", energy);
    }

    /// The magnitude of the interaction shrinks as the atoms move apart.
    #[test]
    fn test_d4_decays_with_distance() {
        let e_close = carbon_dimer_energy(3.0);
        let e_far = carbon_dimer_energy(10.0);
        assert!(
            e_close.abs() > e_far.abs(),
            "D4 should decay: close={}, far={}",
            e_close,
            e_far
        );
    }

    /// Switching on the ATM term must change the total energy of a triangle.
    #[test]
    fn test_d4_three_body() {
        let elements = [6, 6, 6];
        let pos = [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.25, 2.17, 0.0]];
        let energy_with = |three_body: bool| {
            let config = D4Config {
                three_body,
                ..Default::default()
            };
            compute_d4_energy(&elements, &pos, &config).total_energy
        };
        let pair_only = energy_with(false);
        let with_triples = energy_with(true);
        assert!(
            (with_triples - pair_only).abs() > 0.0,
            "3-body should differ from 2-body"
        );
    }

    /// Every component of the numerical gradient must be a finite number.
    #[test]
    fn test_d4_gradient_finite() {
        let elements = [6, 8, 1, 1];
        let pos = [
            [0.0, 0.0, 0.0],
            [1.23, 0.0, 0.0],
            [-0.6, 0.9, 0.0],
            [-0.6, -0.9, 0.0],
        ];
        let grad = compute_d4_gradient(&elements, &pos, &D4Config::default());
        assert!(
            grad.iter().flatten().all(|component| component.is_finite()),
            "Gradient contains NaN/Inf"
        );
    }
}