1use super::d4_data::*;
9
// Two-body rational-damping parameters and three-body scaling.
// NOTE(review): these appear to be the GFN2-xTB D4 parameters (the file also
// reads `REFQ_GFN2`) — confirm against the reference parameterisation.

/// Scaling of the C6 (`r^-6`) two-body term.
const D4_S6: f64 = 1.0;
/// Scaling of the C8 (`r^-8`) two-body term, where C8 is taken as `rrij * C6`.
const D4_S8: f64 = 2.7;
/// Slope of the damping radius `r0 = D4_A1 * sqrt(rrij) + D4_A2`.
const D4_A1: f64 = 0.52;
/// Offset of the damping radius `r0 = D4_A1 * sqrt(rrij) + D4_A2`.
const D4_A2: f64 = 5.0;
/// Scaling of the three-body (Axilrod–Teller–Muto) term; a value of ~0 disables it.
const D4_S9: f64 = 5.0;

/// Base exponent of the Gaussian CN weights: the n-th Gaussian of a reference
/// uses the exponent `n * D4_WF`.
const D4_WF: f64 = 6.0;

/// Parameter `a` of the charge-scaling function `zeta`; `exp(D4_GA)` is its upper bound.
const D4_GA: f64 = 3.0;
/// Factor turning an element's chemical hardness into the steepness `c` of `zeta`.
const D4_GC: f64 = 2.0;

/// Pair-distance cutoff for the coordination number (also used to screen
/// three-body triples).
const D4_CN_CUTOFF: f64 = 25.0;

/// Pair-distance cutoff for the two-body dispersion matrix.
const D4_DISP2_CUTOFF: f64 = 50.0;
29
/// Geometry-dependent D4 dispersion model.
///
/// Built once per geometry by `D4Model::new`; afterwards the charge-dependent
/// quantities (reference weights, energy, charge potential) are re-evaluated
/// from the precomputed tables without touching the geometry again.
#[derive(Debug, Clone)]
pub struct D4Model {
    /// Number of atoms.
    pub nat: usize,
    /// Atomic number of every atom.
    pub elements: Vec<u8>,
    /// D4 coordination number of every atom.
    pub cn: Vec<f64>,
    /// Damped two-body dispersion matrix, flattened row-major as
    /// `[mref][nat][mref][nat]` = (reference of atom i, atom i, reference of atom j, atom j).
    dispmat_flat: Vec<f64>,
    /// Per-element reference stride (`MAX_REF`); every reference table is padded
    /// to this many entries.
    mref: usize,
    /// Scaled reference polarizabilities indexed `[element type][reference][frequency]`.
    // Only read while building `c6_ref_flat`; presumably kept on the model for
    // inspection/debugging (TODO confirm), hence the dead-code allowance.
    #[allow(dead_code)]
    scaled_alpha: Vec<Vec<Vec<f64>>>,
    /// Reference C6 coefficients, flattened row-major as `[n_types][n_types][mref][mref]`.
    c6_ref_flat: Vec<f64>,
    /// Distinct atomic numbers, in order of first appearance in `elements`.
    elem_types: Vec<u8>,
    /// Index into `elem_types` for every atom.
    atom_to_type: Vec<usize>,
}
59
/// Per-atom reference weights produced by `D4Model::weight_references`.
#[derive(Debug, Clone)]
pub struct D4Weights {
    /// `gwvec[iat][iref]`: weight of reference `iref` of atom `iat`
    /// (each row padded with zeros to `MAX_REF` entries).
    pub gwvec: Vec<Vec<f64>>,
    /// `dgwdq[iat][iref]`: derivative of `gwvec[iat][iref]` with respect to the
    /// charge of atom `iat` (CN weight held fixed).
    pub dgwdq: Vec<Vec<f64>>,
}
67
68impl D4Model {
69 pub fn new(elements: &[u8], positions: &[[f64; 3]]) -> Self {
74 let nat = elements.len();
75 let mref = MAX_REF;
76
77 let cn = compute_d4_cn(elements, positions);
79
80 let mut elem_types: Vec<u8> = Vec::new();
82 let mut atom_to_type = vec![0usize; nat];
83 for (iat, &z) in elements.iter().enumerate() {
84 if let Some(pos) = elem_types.iter().position(|&e| e == z) {
85 atom_to_type[iat] = pos;
86 } else {
87 atom_to_type[iat] = elem_types.len();
88 elem_types.push(z);
89 }
90 }
91
92 let scaled_alpha = compute_scaled_alpha(elements);
94
95 let n_types = elem_types.len();
97 let mut c6_ref_flat = vec![0.0f64; n_types * n_types * mref * mref];
98 for (it, &zi) in elem_types.iter().enumerate() {
99 let nref_i = get_nref(zi);
100 for (jt, &zj) in elem_types.iter().enumerate() {
101 let nref_j = get_nref(zj);
102 for iref in 0..nref_i {
103 let alpha_i = &scaled_alpha[it][iref];
104 if alpha_i.iter().all(|&v| v == 0.0) {
105 continue;
106 }
107 for jref in 0..nref_j {
108 let alpha_j = &scaled_alpha[jt][jref];
109 if alpha_j.iter().all(|&v| v == 0.0) {
110 continue;
111 }
112 let mut c6 = 0.0;
114 for k in 0..NFREQ {
115 c6 += CP_WEIGHTS[k] * alpha_i[k] * alpha_j[k];
116 }
117 c6 *= 3.0 / std::f64::consts::PI;
118 let idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
119 c6_ref_flat[idx] = c6;
120 }
121 }
122 }
123 }
124
125 let mut dispmat_flat = vec![0.0f64; mref * nat * mref * nat];
127 let cutoff2 = D4_DISP2_CUTOFF * D4_DISP2_CUTOFF;
128
129 for iat in 0..nat {
130 let iz = elements[iat];
131 let it = atom_to_type[iat];
132 let nref_i = get_nref(iz);
133 for jat in 0..=iat {
134 let jz = elements[jat];
135 let jt = atom_to_type[jat];
136 let nref_j = get_nref(jz);
137
138 let dx = positions[iat][0] - positions[jat][0];
139 let dy = positions[iat][1] - positions[jat][1];
140 let dz = positions[iat][2] - positions[jat][2];
141 let r2 = dx * dx + dy * dy + dz * dz;
142
143 if r2 > cutoff2 || r2 < 1e-15 {
144 continue;
145 }
146
147 let rrij = 3.0 * R4R2[iz as usize - 1] * R4R2[jz as usize - 1];
148 let r0ij = D4_A1 * rrij.sqrt() + D4_A2;
149 let t6 = 1.0 / (r2.powi(3) + r0ij.powi(6));
150 let t8 = 1.0 / (r2.powi(4) + r0ij.powi(8));
151 let de = -(D4_S6 * t6 + D4_S8 * rrij * t8);
152
153 for iref in 0..nref_i {
154 for jref in 0..nref_j {
155 let c6_idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
156 let c6 = c6_ref_flat[c6_idx];
157 let val = de * c6;
158
159 let idx_ij = ((iref * nat + iat) * mref + jref) * nat + jat;
160 let idx_ji = ((jref * nat + jat) * mref + iref) * nat + iat;
161 dispmat_flat[idx_ij] = val;
162 dispmat_flat[idx_ji] = val;
163 }
164 }
165 }
166 }
167
168 D4Model {
169 nat,
170 elements: elements.to_vec(),
171 cn,
172 dispmat_flat,
173 mref,
174 scaled_alpha,
175 c6_ref_flat,
176 elem_types,
177 atom_to_type,
178 }
179 }
180
181 pub fn weight_references(&self, charges: &[f64]) -> D4Weights {
187 let nat = self.nat;
188 let mut gwvec = vec![vec![0.0f64; MAX_REF]; nat];
189 let mut dgwdq = vec![vec![0.0f64; MAX_REF]; nat];
190
191 for iat in 0..nat {
192 let z = self.elements[iat];
193 let zi = z as usize;
194 if zi == 0 || zi > MAX_ELEM {
195 continue;
196 }
197 let nref = get_nref(z);
198 if nref == 0 {
199 continue;
200 }
201
202 let cn_val = self.cn[iat];
203 let q_val = charges[iat];
204
205 let zeff_i = EFFECTIVE_NUCLEAR_CHARGE[zi - 1];
206 let gi = CHEMICAL_HARDNESS[zi - 1] * D4_GC;
207
208 let mut ngw = vec![0usize; nref];
211 {
212 let max_cn_int: usize = 19;
213 let mut cnc = vec![0usize; max_cn_int + 1];
214 cnc[0] = 1; for iref in 0..nref {
216 let rcn = get_refcn(z, iref);
217 let icn = (rcn.round() as usize).min(max_cn_int);
218 cnc[icn] += 1;
219 }
220 for iref in 0..nref {
221 let rcn = get_refcn(z, iref);
222 let icn = (rcn.round() as usize).min(max_cn_int);
223 let c = cnc[icn];
224 ngw[iref] = c * (c + 1) / 2;
225 }
226 }
227
228 let mut covcn = vec![0.0f64; nref];
230 let mut refq = vec![0.0f64; nref];
231 for iref in 0..nref {
232 covcn[iref] = get_refcovcn(z, iref);
233 refq[iref] = get_refq(z, iref);
234 }
235
236 let mut norm = 0.0f64;
238 for iref in 0..nref {
239 for igw in 1..=ngw[iref] {
240 let wf = igw as f64 * D4_WF;
241 norm += weight_cn(wf, cn_val, covcn[iref]);
242 }
243 }
244 let norm_inv = if norm.abs() > 1e-150 { 1.0 / norm } else { 0.0 };
245
246 for iref in 0..nref {
248 let mut expw = 0.0f64;
249 for igw in 1..=ngw[iref] {
250 let wf = igw as f64 * D4_WF;
251 expw += weight_cn(wf, cn_val, covcn[iref]);
252 }
253 let mut gwk = expw * norm_inv;
254
255 if !gwk.is_finite() || norm_inv == 0.0 {
257 let max_covcn = covcn[..nref]
258 .iter()
259 .cloned()
260 .fold(f64::NEG_INFINITY, f64::max);
261 gwk = if (max_covcn - covcn[iref]).abs() < 1e-12 {
262 1.0
263 } else {
264 0.0
265 };
266 }
267
268 let z_val = zeta(D4_GA, gi, refq[iref] + zeff_i, q_val + zeff_i);
269 let dz_val = dzeta(D4_GA, gi, refq[iref] + zeff_i, q_val + zeff_i);
270
271 gwvec[iat][iref] = gwk * z_val;
272 dgwdq[iat][iref] = gwk * dz_val;
273 }
274 }
275
276 D4Weights { gwvec, dgwdq }
277 }
278
279 pub fn get_potential(&self, weights: &D4Weights) -> Vec<f64> {
286 let nat = self.nat;
287 let mref = self.mref;
288 let mut vat = vec![0.0f64; nat];
289
290 for iat in 0..nat {
291 let nref_i = get_nref(self.elements[iat]);
292 let mut vvec = vec![0.0f64; nref_i];
294 for iref in 0..nref_i {
295 for jat in 0..nat {
296 let nref_j = get_nref(self.elements[jat]);
297 for jref in 0..nref_j {
298 let idx = ((iref * nat + iat) * mref + jref) * nat + jat;
299 vvec[iref] += self.dispmat_flat[idx] * weights.gwvec[jat][jref];
300 }
301 }
302 }
303 for iref in 0..nref_i {
305 vat[iat] += vvec[iref] * weights.dgwdq[iat][iref];
306 }
307 }
308
309 vat
310 }
311
312 pub fn get_energy(&self, weights: &D4Weights) -> f64 {
316 let nat = self.nat;
317 let mref = self.mref;
318 let mut energy = 0.0f64;
319
320 for iat in 0..nat {
321 let nref_i = get_nref(self.elements[iat]);
322 let mut vvec = vec![0.0f64; nref_i];
324 for iref in 0..nref_i {
325 for jat in 0..nat {
326 let nref_j = get_nref(self.elements[jat]);
327 for jref in 0..nref_j {
328 let idx = ((iref * nat + iat) * mref + jref) * nat + jat;
329 vvec[iref] += self.dispmat_flat[idx] * weights.gwvec[jat][jref];
330 }
331 }
332 }
333 for iref in 0..nref_i {
334 energy += 0.5 * vvec[iref] * weights.gwvec[iat][iref];
335 }
336 }
337
338 energy
339 }
340
341 pub fn get_atm_energy(&self, positions: &[[f64; 3]]) -> f64 {
346 let nat = self.nat;
347 if nat < 3 || D4_S9.abs() < 1e-15 {
348 return 0.0;
349 }
350
351 let zero_charges = vec![0.0f64; nat];
353 let w0 = self.weight_references(&zero_charges);
354 let c6 = self.get_c6_matrix(&w0);
355
356 let cutoff2 = D4_CN_CUTOFF * D4_CN_CUTOFF;
357 let alp3 = 16.0 / 3.0;
358 let mut energy = 0.0f64;
359
360 for iat in 0..nat {
361 let iz = self.elements[iat] as usize;
362 for jat in 0..iat {
363 let jz = self.elements[jat] as usize;
364 let c6ij = c6[jat * nat + iat];
365 let r0ij = D4_A1 * (3.0 * R4R2[iz - 1] * R4R2[jz - 1]).sqrt() + D4_A2;
366
367 let vij = [
368 positions[jat][0] - positions[iat][0],
369 positions[jat][1] - positions[iat][1],
370 positions[jat][2] - positions[iat][2],
371 ];
372 let r2ij = vij[0] * vij[0] + vij[1] * vij[1] + vij[2] * vij[2];
373 if r2ij > cutoff2 || r2ij < 1e-15 {
374 continue;
375 }
376
377 for kat in 0..jat {
378 let kz = self.elements[kat] as usize;
379 let c6ik = c6[kat * nat + iat];
380 let c6jk = c6[kat * nat + jat];
381 let c9 = -D4_S9 * (c6ij * c6ik * c6jk).abs().sqrt();
382
383 let r0ik = D4_A1 * (3.0 * R4R2[kz - 1] * R4R2[iz - 1]).sqrt() + D4_A2;
384 let r0jk = D4_A1 * (3.0 * R4R2[kz - 1] * R4R2[jz - 1]).sqrt() + D4_A2;
385 let r0 = r0ij * r0ik * r0jk;
386
387 let triple = triple_scale(iat, jat, kat);
389
390 let vik = [
391 positions[kat][0] - positions[iat][0],
392 positions[kat][1] - positions[iat][1],
393 positions[kat][2] - positions[iat][2],
394 ];
395 let r2ik = vik[0] * vik[0] + vik[1] * vik[1] + vik[2] * vik[2];
396 if r2ik > cutoff2 || r2ik < 1e-15 {
397 continue;
398 }
399
400 let vjk = [vik[0] - vij[0], vik[1] - vij[1], vik[2] - vij[2]];
401 let r2jk = vjk[0] * vjk[0] + vjk[1] * vjk[1] + vjk[2] * vjk[2];
402 if r2jk > cutoff2 || r2jk < 1e-15 {
403 continue;
404 }
405
406 let r2 = r2ij * r2ik * r2jk;
407 let r1 = r2.sqrt();
408 let r3 = r2 * r1;
409 let r5 = r3 * r2;
410
411 let fdmp = 1.0 / (1.0 + 6.0 * (r0 / r1).powf(alp3));
412 let ang =
413 0.375 * (r2ij + r2jk - r2ik) * (r2ij - r2jk + r2ik) * (-r2ij + r2jk + r2ik)
414 / r5
415 + 1.0 / r3;
416
417 let rr = ang * fdmp;
418 let de = rr * c9 * triple / 6.0;
419
420 energy -= 6.0 * de;
423 }
424 }
425 }
426
427 energy
428 }
429
430 fn get_c6_matrix(&self, weights: &D4Weights) -> Vec<f64> {
432 let nat = self.nat;
433 let n_types = self.elem_types.len();
434 let mref = self.mref;
435 let mut c6 = vec![0.0f64; nat * nat];
436
437 for iat in 0..nat {
438 let it = self.atom_to_type[iat];
439 let nref_i = get_nref(self.elements[iat]);
440 for jat in 0..nat {
441 let jt = self.atom_to_type[jat];
442 let nref_j = get_nref(self.elements[jat]);
443 let mut val = 0.0;
444 for iref in 0..nref_i {
445 for jref in 0..nref_j {
446 let c6_idx = (it * n_types + jt) * mref * mref + iref * mref + jref;
447 val += weights.gwvec[iat][iref]
448 * weights.gwvec[jat][jref]
449 * self.c6_ref_flat[c6_idx];
450 }
451 }
452 c6[jat * nat + iat] = val;
453 }
454 }
455
456 c6
457 }
458}
459
460fn get_nref(z: u8) -> usize {
464 let zi = z as usize;
465 if zi == 0 || zi > MAX_ELEM {
466 return 0;
467 }
468 REFN[zi - 1]
469}
470
471fn get_refcn(z: u8, iref: usize) -> f64 {
473 let zi = z as usize;
474 if zi == 0 || zi > MAX_ELEM {
475 return 0.0;
476 }
477 REFCN[(zi - 1) * MAX_REF + iref]
478}
479
480fn get_refcovcn(z: u8, iref: usize) -> f64 {
482 let zi = z as usize;
483 if zi == 0 || zi > MAX_ELEM {
484 return 0.0;
485 }
486 REFCOVCN[(zi - 1) * MAX_REF + iref]
487}
488
489fn get_refq(z: u8, iref: usize) -> f64 {
491 let zi = z as usize;
492 if zi == 0 || zi > MAX_ELEM {
493 return 0.0;
494 }
495 REFQ_GFN2[(zi - 1) * MAX_REF + iref]
496}
497
/// Gaussian coordination-number weight `exp(-wf · (cn - cnref)²)`.
///
/// Equals 1 when `cn == cnref` and decays faster for larger `wf`.
fn weight_cn(wf: f64, cn: f64, cnref: f64) -> f64 {
    let delta = cn - cnref;
    let exponent = -wf * delta * delta;
    exponent.exp()
}
503
/// Charge-scaling function `exp(a · (1 - exp(c · (1 - qref/qmod))))`.
///
/// For a negative `qmod` the function saturates at `exp(a)`. It equals 1 when
/// `qmod == qref`.
fn zeta(a: f64, c: f64, qref: f64, qmod: f64) -> f64 {
    let exponent = if qmod < 0.0 {
        a
    } else {
        let inner = (c * (1.0 - qref / qmod)).exp();
        a * (1.0 - inner)
    };
    exponent.exp()
}
513
514fn dzeta(a: f64, c: f64, qref: f64, qmod: f64) -> f64 {
516 if qmod < 0.0 {
517 return 0.0;
518 }
519 let z = zeta(a, c, qref, qmod);
520 -a * c * (c * (1.0 - qref / qmod)).exp() * z * qref / (qmod * qmod)
521}
522
/// Multiplicity correction for an atom triple `(ii, jj, kk)`.
///
/// Returns:
///
/// * 1 when all three indices differ,
/// * 1/6 when all three coincide,
/// * 1/2 when exactly two coincide.
fn triple_scale(ii: usize, jj: usize, kk: usize) -> f64 {
    match (ii == jj, ii == kk, jj == kk) {
        (true, true, _) => 1.0 / 6.0,
        (false, false, false) => 1.0,
        _ => 0.5,
    }
}
537
538fn compute_d4_cn(elements: &[u8], positions: &[[f64; 3]]) -> Vec<f64> {
544 let nat = elements.len();
545 let mut cn = vec![0.0f64; nat];
546 let cutoff2 = D4_CN_CUTOFF * D4_CN_CUTOFF;
547
548 let ka = 7.5f64;
550
551 for iat in 0..nat {
552 let zi = elements[iat] as usize;
553 if zi == 0 || zi > MAX_ELEM {
554 continue;
555 }
556 let rcov_i = COVRAD_D3[zi - 1];
557
558 for jat in 0..nat {
559 if iat == jat {
560 continue;
561 }
562 let zj = elements[jat] as usize;
563 if zj == 0 || zj > MAX_ELEM {
564 continue;
565 }
566 let rcov_j = COVRAD_D3[zj - 1];
567
568 let dx = positions[iat][0] - positions[jat][0];
569 let dy = positions[iat][1] - positions[jat][1];
570 let dz = positions[iat][2] - positions[jat][2];
571 let r2 = dx * dx + dy * dy + dz * dz;
572
573 if r2 > cutoff2 || r2 < 1e-15 {
574 continue;
575 }
576
577 let r = r2.sqrt();
578 let rcov_sum = rcov_i + rcov_j;
579
580 let cn_val = 0.5 * erf(-ka * (r / rcov_sum - 1.0));
582 cn[iat] += cn_val + 0.5;
583 }
584 }
585
586 cn
587}
588
/// Error function via the Abramowitz & Stegun rational approximation 7.1.26.
///
/// The absolute error of the approximation is at most about 1.5e-7. The sign is
/// handled separately, so `erf(-x)` is exactly `-erf(x)`.
fn erf(x: f64) -> f64 {
    const P: f64 = 0.3275911;
    const COEFFS: [f64; 5] = [
        0.254829592,
        -0.284496736,
        1.421413741,
        -1.453152027,
        1.061405429,
    ];

    let sign = if x >= 0.0 { 1.0 } else { -1.0 };
    let ax = x.abs();
    let t = 1.0 / (1.0 + P * ax);

    // Sum a_n · t^n term by term, in the same order as writing the polynomial out.
    let mut t_pow = t;
    let mut poly = 0.0;
    for &coeff in COEFFS.iter() {
        poly += coeff * t_pow;
        t_pow *= t;
    }

    sign * (1.0 - poly * (-ax * ax).exp())
}
602
603fn compute_scaled_alpha(elements: &[u8]) -> Vec<Vec<Vec<f64>>> {
608 let mut elem_types: Vec<u8> = Vec::new();
610 for &z in elements {
611 if !elem_types.contains(&z) {
612 elem_types.push(z);
613 }
614 }
615
616 let mut result = Vec::with_capacity(elem_types.len());
617
618 for &z in &elem_types {
619 let zi = z as usize;
620 let nref = get_nref(z);
621 let mut alphas_for_elem = vec![vec![0.0f64; NFREQ]; MAX_REF];
622
623 for iref in 0..nref {
624 let base_idx = (zi - 1) * MAX_REF + iref;
625 let is_sys = REFSYS[base_idx]; let hc = HCOUNT[base_idx];
627 let asc = ASCALE[base_idx];
628 let rh = REFH[base_idx];
629
630 if is_sys == 0 {
631 let alpha_base = (zi - 1) * MAX_REF * NFREQ + iref * NFREQ;
633 for k in 0..NFREQ {
634 alphas_for_elem[iref][k] = (asc * ALPHAIW[alpha_base + k]).max(0.0);
635 }
636 continue;
637 }
638
639 let ss = if is_sys <= MAX_SEC {
641 SSCALE[is_sys - 1]
642 } else {
643 0.0
644 };
645 let iz_sec = EFFECTIVE_NUCLEAR_CHARGE[is_sys - 1];
646 let eta_sec = CHEMICAL_HARDNESS[is_sys - 1] * D4_GC;
647 let z_scale = zeta(D4_GA, eta_sec, iz_sec, rh + iz_sec);
648
649 let sec_base = (is_sys - 1) * NFREQ;
650 let alpha_base = (zi - 1) * MAX_REF * NFREQ + iref * NFREQ;
651 for k in 0..NFREQ {
652 let sec_val = if is_sys <= MAX_SEC && sec_base + k < SECAIW.len() {
653 ss * SECAIW[sec_base + k] * z_scale
654 } else {
655 0.0
656 };
657 alphas_for_elem[iref][k] =
658 (asc * (ALPHAIW[alpha_base + k] - hc * sec_val)).max(0.0);
659 }
660 }
661
662 result.push(alphas_for_elem);
663 }
664
665 result
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Water test geometry: O followed by two H atoms mirrored through y = 0.
    fn water() -> ([u8; 3], [[f64; 3]; 3]) {
        (
            [8u8, 1, 1],
            [
                [0.0, 0.0, 0.221228620],
                [0.0, 1.430453160, -0.885762480],
                [0.0, -1.430453160, -0.885762480],
            ],
        )
    }

    #[test]
    fn test_water_d4_cn() {
        let (elements, positions) = water();
        let cn = compute_d4_cn(&elements, &positions);
        assert!(
            cn[0] > 1.0 && cn[0] < 2.5,
            "O CN={:.6}, expected 1.0–2.5",
            cn[0]
        );
        assert!(
            cn[1] > 0.5 && cn[1] < 1.5,
            "H CN={:.6}, expected 0.5–1.5",
            cn[1]
        );
        assert!(
            cn[2] > 0.5 && cn[2] < 1.5,
            "H CN={:.6}, expected 0.5–1.5",
            cn[2]
        );
    }

    #[test]
    fn test_water_d4_cn_mirrored_hydrogens_match() {
        let (elements, positions) = water();
        let cn = compute_d4_cn(&elements, &positions);
        assert_eq!(cn[1], cn[2]);
    }

    #[test]
    fn test_water_d4_potential_at_zero_charges() {
        let (elements, positions) = water();
        let model = D4Model::new(&elements, &positions);
        let charges = [0.0, 0.0, 0.0];
        let w = model.weight_references(&charges);
        let vat = model.get_potential(&w);
        let e_sc = model.get_energy(&w);

        eprintln!("D4 vat (q=0): {:?}", vat);
        eprintln!("D4 SC energy (q=0): {:.10e}", e_sc);

        assert!(vat[0].abs() > 1e-6, "O vat should be non-zero");
        assert!((vat[1] - vat[2]).abs() < 1e-12, "H vat should be symmetric");
        assert!(e_sc < 0.0, "SC energy should be negative");
        assert!(
            (e_sc - (-2.506e-4)).abs() < 5e-5,
            "SC energy should match Python: got {:.6e}",
            e_sc
        );
    }

    #[test]
    fn test_atm_energy_small_and_water() {
        let single = D4Model::new(&[8u8], &[[0.0, 0.0, 0.0]]);
        assert_eq!(single.get_atm_energy(&[[0.0, 0.0, 0.0]]), 0.0);

        let (elements, positions) = water();
        let model = D4Model::new(&elements, &positions);
        let e_atm = model.get_atm_energy(&positions);
        assert!(e_atm.is_finite(), "ATM energy should be finite: got {:e}", e_atm);
    }

    #[test]
    fn test_triple_scale_cases() {
        assert_eq!(triple_scale(2, 1, 0), 1.0);
        assert_eq!(triple_scale(1, 1, 0), 0.5);
        assert_eq!(triple_scale(2, 1, 2), 0.5);
        assert_eq!(triple_scale(3, 3, 3), 1.0 / 6.0);
    }

    #[test]
    fn test_zeta_and_dzeta_limits() {
        assert_eq!(zeta(3.0, 2.0, 1.5, 1.5), 1.0);
        assert_eq!(zeta(3.0, 2.0, 1.0, -1.0), 3.0f64.exp());
        assert_eq!(dzeta(3.0, 2.0, 1.0, -1.0), 0.0);
        // At qmod == qref: -a·c·1·1·q/q² = -a·c/q.
        assert!((dzeta(3.0, 2.0, 5.0, 5.0) + 1.2).abs() < 1e-12);
    }

    #[test]
    fn test_weight_cn_peak() {
        assert_eq!(weight_cn(6.0, 2.0, 2.0), 1.0);
        assert_eq!(weight_cn(6.0, 1.0, 0.0), (-6.0f64).exp());
    }

    #[test]
    fn test_erf_accuracy_and_symmetry() {
        let cases = [
            (0.5, 0.520_499_877_813_046_5),
            (1.0, 0.842_700_792_949_714_9),
            (2.0, 0.995_322_265_018_952_7),
        ];
        for &(x, expected) in cases.iter() {
            assert!(
                (erf(x) - expected).abs() < 2e-7,
                "erf({}) = {}, expected {}",
                x,
                erf(x),
                expected
            );
            assert_eq!(erf(-x), -erf(x));
        }
    }
}