1use super::params::{c8_from_c6, d4_coordination_number, dynamic_c6};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct D4Config {
11 pub s6: f64,
13 pub s8: f64,
15 pub a1: f64,
17 pub a2: f64,
19 pub three_body: bool,
21 pub s9: f64,
23}
24
25impl Default for D4Config {
26 fn default() -> Self {
27 Self {
28 s6: 1.0,
29 s8: 0.95,
30 a1: 0.45,
31 a2: 4.0,
32 three_body: false,
33 s9: 1.0,
34 }
35 }
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct D4Result {
45 pub e2_body: f64,
47 pub e3_body: f64,
49 pub total_energy: f64,
51 pub total_kcal_mol: f64,
53 pub coordination_numbers: Vec<f64>,
55}
56
57pub fn compute_d4_energy(elements: &[u8], positions: &[[f64; 3]], config: &D4Config) -> D4Result {
59 let n = elements.len();
60 let cn = d4_coordination_number(elements, positions);
61 let ang_to_bohr = 1.0 / 0.529177;
62
63 #[cfg(feature = "parallel")]
64 let e2 = {
65 use rayon::prelude::*;
66 (0..n)
67 .into_par_iter()
68 .map(|i| {
69 ((i + 1)..n)
70 .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
71 .sum::<f64>()
72 })
73 .sum::<f64>()
74 };
75
76 #[cfg(not(feature = "parallel"))]
77 let e2 = (0..n)
78 .map(|i| {
79 ((i + 1)..n)
80 .map(|j| pair_energy(i, j, elements, positions, &cn, config, ang_to_bohr))
81 .sum::<f64>()
82 })
83 .sum::<f64>();
84
85 #[cfg(feature = "parallel")]
86 let e3 = if config.three_body && n >= 3 {
87 use rayon::prelude::*;
88 (0..n)
89 .into_par_iter()
90 .map(|i| {
91 let mut subtotal = 0.0;
92 for j in (i + 1)..n {
93 for k in (j + 1)..n {
94 subtotal +=
95 triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
96 }
97 }
98 subtotal
99 })
100 .sum::<f64>()
101 } else {
102 0.0
103 };
104
105 #[cfg(not(feature = "parallel"))]
106 let e3 = if config.three_body && n >= 3 {
107 let mut total = 0.0;
108 for i in 0..n {
109 for j in (i + 1)..n {
110 for k in (j + 1)..n {
111 total += triple_energy(i, j, k, elements, positions, &cn, config, ang_to_bohr);
112 }
113 }
114 }
115 total
116 } else {
117 0.0
118 };
119
120 let total = e2 + e3;
121 let hartree_to_kcal = 627.509;
122
123 D4Result {
124 e2_body: e2,
125 e3_body: e3,
126 total_energy: total,
127 total_kcal_mol: total * hartree_to_kcal,
128 coordination_numbers: cn,
129 }
130}
131
132fn pair_energy(
133 i: usize,
134 j: usize,
135 elements: &[u8],
136 positions: &[[f64; 3]],
137 cn: &[f64],
138 config: &D4Config,
139 ang_to_bohr: f64,
140) -> f64 {
141 let dx = (positions[i][0] - positions[j][0]) * ang_to_bohr;
142 let dy = (positions[i][1] - positions[j][1]) * ang_to_bohr;
143 let dz = (positions[i][2] - positions[j][2]) * ang_to_bohr;
144 let r = (dx * dx + dy * dy + dz * dz).sqrt();
145
146 if r < 1e-10 {
147 return 0.0;
148 }
149
150 let c6 = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
151 let c8 = c8_from_c6(c6, elements[i], elements[j]);
152 let r0 = if c6 > 1e-10 { (c8 / c6).sqrt() } else { 5.0 };
153 let r_cut = config.a1 * r0 + config.a2;
154
155 let r6 = r.powi(6);
156 let damp6 = r6 / (r6 + r_cut.powi(6));
157 let term6 = -config.s6 * c6 / r6 * damp6;
158
159 let r8 = r.powi(8);
160 let damp8 = r8 / (r8 + r_cut.powi(8));
161 let term8 = -config.s8 * c8 / r8 * damp8;
162
163 term6 + term8
164}
165
166fn triple_energy(
167 i: usize,
168 j: usize,
169 k: usize,
170 elements: &[u8],
171 positions: &[[f64; 3]],
172 cn: &[f64],
173 config: &D4Config,
174 ang_to_bohr: f64,
175) -> f64 {
176 let r_ab = distance_bohr(positions, i, j, ang_to_bohr);
177 let r_bc = distance_bohr(positions, j, k, ang_to_bohr);
178 let r_ca = distance_bohr(positions, k, i, ang_to_bohr);
179
180 if r_ab < 1e-10 || r_bc < 1e-10 || r_ca < 1e-10 {
181 return 0.0;
182 }
183
184 let c6_ab = dynamic_c6(elements[i], elements[j], cn[i], cn[j]);
185 let c6_bc = dynamic_c6(elements[j], elements[k], cn[j], cn[k]);
186 let c6_ca = dynamic_c6(elements[k], elements[i], cn[k], cn[i]);
187
188 let c9 = -(c6_ab * c6_bc * c6_ca).abs().sqrt();
190
191 let cos_a = (r_ab * r_ab + r_ca * r_ca - r_bc * r_bc) / (2.0 * r_ab * r_ca);
192 let cos_b = (r_ab * r_ab + r_bc * r_bc - r_ca * r_ca) / (2.0 * r_ab * r_bc);
193 let cos_c = (r_bc * r_bc + r_ca * r_ca - r_ab * r_ab) / (2.0 * r_bc * r_ca);
194 let angular = 3.0 * cos_a * cos_b * cos_c + 1.0;
195 let r_prod = r_ab * r_bc * r_ca;
196
197 config.s9 * c9 * angular / r_prod.powi(3)
198}
199
200pub fn compute_d4_gradient(
202 elements: &[u8],
203 positions: &[[f64; 3]],
204 config: &D4Config,
205) -> Vec<[f64; 3]> {
206 let n = elements.len();
207 let h = 1e-5;
208 let mut gradient = vec![[0.0; 3]; n];
209
210 for i in 0..n {
211 for d in 0..3 {
212 let mut pos_p = positions.to_vec();
213 let mut pos_m = positions.to_vec();
214 pos_p[i][d] += h;
215 pos_m[i][d] -= h;
216
217 let ep = compute_d4_energy(elements, &pos_p, config).total_energy;
218 let em = compute_d4_energy(elements, &pos_m, config).total_energy;
219
220 gradient[i][d] = (ep - em) / (2.0 * h);
221 }
222 }
223
224 gradient
225}
226
/// Euclidean distance between atoms `i` and `j`, with each coordinate difference
/// multiplied by `ang_to_bohr` before squaring.
///
/// Passing `1.0 / 0.529177` turns Ångström coordinates into a Bohr distance.
fn distance_bohr(positions: &[[f64; 3]], i: usize, j: usize, ang_to_bohr: f64) -> f64 {
    let (first, second) = (&positions[i], &positions[j]);
    first
        .iter()
        .zip(second.iter())
        .map(|(p, q)| {
            let delta = (p - q) * ang_to_bohr;
            delta * delta
        })
        .sum::<f64>()
        .sqrt()
}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_d4_energy_negative() {
        let elements = [6, 6];
        let pos = [[0.0, 0.0, 0.0], [3.5, 0.0, 0.0]];
        let config = D4Config::default();
        let result = compute_d4_energy(&elements, &pos, &config);
        assert!(
            result.total_energy < 0.0,
            "D4 energy should be negative: {}",
            result.total_energy
        );
    }

    #[test]
    fn test_d4_decays_with_distance() {
        let elements = [6, 6];
        let e_close = compute_d4_energy(
            &elements,
            &[[0.0, 0.0, 0.0], [3.0, 0.0, 0.0]],
            &D4Config::default(),
        )
        .total_energy;
        let e_far = compute_d4_energy(
            &elements,
            &[[0.0, 0.0, 0.0], [10.0, 0.0, 0.0]],
            &D4Config::default(),
        )
        .total_energy;
        assert!(
            e_close.abs() > e_far.abs(),
            "D4 should decay: close={}, far={}",
            e_close,
            e_far
        );
    }

    #[test]
    fn test_d4_three_body() {
        let elements = [6, 6, 6];
        let pos = [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.25, 2.17, 0.0]];
        let r2 = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: false,
                ..Default::default()
            },
        );
        let r3 = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: true,
                ..Default::default()
            },
        );

        // Disabled three-body term contributes exactly nothing.
        assert_eq!(r2.e3_body, 0.0);
        assert!(r3.e3_body != 0.0, "3-body term should be non-zero for a triangle");
        // The flag must not change the pairwise sum. The tolerance covers the
        // summation-order differences possible with the `parallel` feature.
        assert!(
            (r3.e2_body - r2.e2_body).abs() < 1e-12,
            "2-body term changed: {} vs {}",
            r2.e2_body,
            r3.e2_body
        );
        assert!(
            (r3.total_energy - r2.total_energy).abs() > 0.0,
            "3-body should differ from 2-body"
        );
    }

    #[test]
    fn test_d4_result_consistency() {
        let elements = [6, 6, 6];
        let pos = [[0.0, 0.0, 0.0], [2.5, 0.0, 0.0], [1.25, 2.17, 0.0]];
        let result = compute_d4_energy(
            &elements,
            &pos,
            &D4Config {
                three_body: true,
                ..Default::default()
            },
        );
        assert_eq!(result.total_energy, result.e2_body + result.e3_body);
        assert_eq!(result.total_kcal_mol, result.total_energy * 627.509);
    }

    #[test]
    fn test_d4_gradient_finite() {
        let elements = [6, 8, 1, 1];
        let pos = [
            [0.0, 0.0, 0.0],
            [1.23, 0.0, 0.0],
            [-0.6, 0.9, 0.0],
            [-0.6, -0.9, 0.0],
        ];
        let grad = compute_d4_gradient(&elements, &pos, &D4Config::default());
        assert_eq!(grad.len(), elements.len());
        for g in &grad {
            for &d in g {
                assert!(d.is_finite(), "Gradient contains NaN/Inf");
            }
        }
    }
}