1use super::params::{count_pm3_electrons, get_pm3_params, num_pm3_basis_functions};
11use nalgebra::DMatrix;
12use serde::{Deserialize, Serialize};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct Pm3Result {
17 pub orbital_energies: Vec<f64>,
19 pub electronic_energy: f64,
21 pub nuclear_repulsion: f64,
23 pub total_energy: f64,
25 pub heat_of_formation: f64,
27 pub n_basis: usize,
29 pub n_electrons: usize,
31 pub homo_energy: f64,
33 pub lumo_energy: f64,
35 pub gap: f64,
37 pub mulliken_charges: Vec<f64>,
39 pub scf_iterations: usize,
41 pub converged: bool,
43}
44
/// Length of one bohr expressed in angstrom.
pub(crate) const BOHR_TO_ANGSTROM: f64 = 0.529177;
/// Number of bohr per angstrom (reciprocal of [`BOHR_TO_ANGSTROM`]).
pub(crate) const ANGSTROM_TO_BOHR: f64 = 1.0 / BOHR_TO_ANGSTROM;
/// Electron-volts per hartree.
pub(crate) const EV_PER_HARTREE: f64 = 27.2114;
/// Conversion factor from eV to kcal/mol.
pub(crate) const EV_TO_KCAL: f64 = 23.0605;
/// Distance (bohr) below which the screened Coulomb term stops growing.
pub(crate) const PM3_GAMMA_FLOOR_BOHR: f64 = 0.5;
50
51fn pm3_isolated_atom_energy(z: u8, p: &super::params::Pm3Params) -> f64 {
57 match z {
58 1 => p.uss,
60 _ => {
64 let n_val = p.core_charge;
65 let n_p = (n_val - 2.0).max(0.0);
66 let e_one = 2.0 * p.uss + n_p * p.upp;
67 let e_two_ss = p.gss;
69 let e_two_sp = n_p * (p.gsp - 0.5 * p.hsp);
71 let gpp_avg = (p.gpp + 2.0 * p.gp2) / 3.0;
73 let e_two_pp = if n_p > 1.0 {
74 0.5 * n_p * (n_p - 1.0) * gpp_avg / 3.0
75 } else {
76 0.0
77 };
78 e_one + e_two_ss + e_two_sp + e_two_pp
79 }
80 }
81}
82
83pub(crate) fn distance_bohr(pos_a: &[f64; 3], pos_b: &[f64; 3]) -> f64 {
85 let dx = (pos_a[0] - pos_b[0]) * ANGSTROM_TO_BOHR;
86 let dy = (pos_a[1] - pos_b[1]) * ANGSTROM_TO_BOHR;
87 let dz = (pos_a[2] - pos_b[2]) * ANGSTROM_TO_BOHR;
88 (dx * dx + dy * dy + dz * dz).sqrt()
89}
90
91pub(crate) fn screened_coulomb_gamma_ev(r_bohr: f64) -> f64 {
92 EV_PER_HARTREE / r_bohr.max(PM3_GAMMA_FLOOR_BOHR)
93}
94
95pub(crate) fn screened_coulomb_gamma_derivative_ev_per_angstrom(r_bohr: f64) -> f64 {
96 if r_bohr > PM3_GAMMA_FLOOR_BOHR {
97 -EV_PER_HARTREE * ANGSTROM_TO_BOHR / (r_bohr * r_bohr)
98 } else {
99 0.0
100 }
101}
102
/// Overlap integral between two normalised 1s-type Slater orbitals.
///
/// Uses the closed-form expression in prolate-spheroidal coordinates
/// (Mulliken, Rieke, Orloff & Orloff):
///
/// `S = 1/4 * p^3 * (1 - tau^2)^(3/2) * [A2(p) * B0(p*tau) - A0(p) * B2(p*tau)]`
///
/// with `p = (zeta_a + zeta_b) * R / 2` and
/// `tau = (zeta_a - zeta_b) / (zeta_a + zeta_b)`. For equal exponents this
/// reduces to the familiar `exp(-p) * (1 + p + p^2 / 3)`.
///
/// The previous implementation only matched that equal-exponent limit: for
/// different exponents it jumped discontinuously, frequently saturated at the
/// clamp value of 1.0, and returned 0 for coincident centres.
///
/// # Parameters
/// * `zeta_a`, `zeta_b` - orbital exponents; non-positive (or NaN) exponents
///   yield an overlap of `0.0`.
/// * `r_bohr` - centre-to-centre distance in bohr; negative values are treated
///   as zero separation.
///
/// # Returns
/// The overlap, clamped to `[-1, 1]` to absorb rounding noise.
pub(crate) fn sto_ss_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    /// Number of terms of the power series used for small `p * tau`.
    const SERIES_TERMS: u32 = 12;

    if !(zeta_a > 0.0 && zeta_b > 0.0) {
        return 0.0;
    }

    let r = if r_bohr < 0.0 { 0.0 } else { r_bohr };
    let zeta_sum = zeta_a + zeta_b;
    let tau = (zeta_a - zeta_b) / zeta_sum;
    let p = 0.5 * zeta_sum * r;
    if p.is_infinite() {
        // Infinitely separated centres do not overlap.
        return 0.0;
    }
    // |p * tau|; the B integrals of even order are even functions of it.
    let x = 0.5 * (zeta_a - zeta_b).abs() * r;

    // exp(-p) * B0(x) and exp(-p) * B2(x), where
    // B_k(x) = integral_{-1}^{1} u^k * exp(-x * u) du.
    // Folding exp(-p) in keeps every exponent non-positive (x <= p), so
    // nothing overflows at large separations.
    let (b0, b2) = if x < 1.0 {
        // B_k(x) = 2 * sum_m x^(2m) / ((2m)! * (k + 2m + 1)): all terms are
        // positive, so there is no cancellation near x = 0.
        let x2 = x * x;
        let mut coeff = 1.0; // x^(2m) / (2m)!
        let mut sum0 = 0.0;
        let mut sum2 = 0.0;
        for m in 0..SERIES_TERMS {
            let two_m = 2.0 * f64::from(m);
            sum0 += coeff / (two_m + 1.0);
            sum2 += coeff / (two_m + 3.0);
            coeff *= x2 / ((two_m + 1.0) * (two_m + 2.0));
        }
        let damping = (-p).exp();
        (2.0 * damping * sum0, 2.0 * damping * sum2)
    } else {
        let grow = (x - p).exp();
        let decay = (-x - p).exp();
        let sinh_scaled = 0.5 * (grow - decay);
        let cosh_scaled = 0.5 * (grow + decay);
        (
            2.0 * sinh_scaled / x,
            2.0 * ((x * x + 2.0) * sinh_scaled - 2.0 * x * cosh_scaled) / (x * x * x),
        )
    };

    // p^3 * A2(p) = exp(-p) * (p^2 + 2p + 2) and p^3 * A0(p) = exp(-p) * p^2,
    // so no division by p is needed and R = 0 is handled by the same formula.
    let radial = (p * p + 2.0 * p + 2.0) * b0 - p * p * b2;
    let normalisation = (1.0 - tau * tau).max(0.0).powf(1.5);

    (0.25 * normalisation * radial).clamp(-1.0, 1.0)
}
138
139fn diagonalize_fock(fock: &DMatrix<f64>, overlap: &DMatrix<f64>) -> (Vec<f64>, DMatrix<f64>) {
140 let n_basis = fock.nrows();
141
142 let s_eigen = overlap.clone().symmetric_eigen();
143 let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
144 for k in 0..n_basis {
145 let val = s_eigen.eigenvalues[k];
146 if val > 1e-8 {
147 let inv_sqrt = 1.0 / val.sqrt();
148 let col = s_eigen.eigenvectors.column(k);
149 for i in 0..n_basis {
150 for j in 0..n_basis {
151 s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
152 }
153 }
154 }
155 }
156
157 let f_prime = &s_half_inv * fock * &s_half_inv;
158 let eigen = f_prime.symmetric_eigen();
159
160 let mut indices: Vec<usize> = (0..n_basis).collect();
161 indices.sort_by(|&a, &b| {
162 eigen.eigenvalues[a]
163 .partial_cmp(&eigen.eigenvalues[b])
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166
167 let mut orbital_energies = vec![0.0; n_basis];
168 for (new_idx, &old_idx) in indices.iter().enumerate() {
169 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
170 }
171
172 let c_prime = &eigen.eigenvectors;
173 let c_full = &s_half_inv * c_prime;
174 let mut coefficients = DMatrix::zeros(n_basis, n_basis);
175 for new_idx in 0..n_basis {
176 let old_idx = indices[new_idx];
177 for i in 0..n_basis {
178 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
179 }
180 }
181
182 (orbital_energies, coefficients)
183}
184
185fn build_density_matrix(coefficients: &DMatrix<f64>, n_occ: usize) -> DMatrix<f64> {
186 let n_basis = coefficients.nrows();
187 let mut density = DMatrix::zeros(n_basis, n_basis);
188 for i in 0..n_basis {
189 for j in 0..n_basis {
190 let mut val = 0.0;
191 for k in 0..n_occ.min(n_basis) {
192 val += coefficients[(i, k)] * coefficients[(j, k)];
193 }
194 density[(i, j)] = 2.0 * val;
195 }
196 }
197 density
198}
199
200pub(crate) fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
202 let mut basis = Vec::new();
204 for (i, &z) in elements.iter().enumerate() {
205 let n_bf = num_pm3_basis_functions(z);
206 if n_bf >= 1 {
207 basis.push((i, 0, 0)); }
209 if n_bf >= 4 {
210 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
214 }
215 basis
216}
217
/// Two-centre contribution to each diagonal Fock element.
///
/// The diagonal density entries are first summed into per-atom populations.
/// For each basis function on atom A the result is then
/// `sum_{B != A} population_B * gamma_ab_mat[A][B]`.
///
/// * `density_diag` - diagonal of the density matrix, one entry per basis function.
/// * `basis_map` - `(atom index, l, m)` per basis function; only the atom index is used.
/// * `gamma_ab_mat` - pairwise interaction values indexed `[atom_a][atom_b]`.
/// * `n_atoms` - number of atoms (length of the population vector).
fn compute_pm3_two_center_diag_cpu(
    density_diag: &[f64],
    basis_map: &[(usize, u8, u8)],
    gamma_ab_mat: &[Vec<f64>],
    n_atoms: usize,
) -> Vec<f64> {
    let mut populations = vec![0.0; n_atoms];
    for (idx, &occupancy) in density_diag.iter().enumerate() {
        let owner = basis_map[idx].0;
        populations[owner] += occupancy;
    }

    let mut per_basis = Vec::with_capacity(basis_map.len());
    for &(center, _, _) in basis_map {
        let mut acc = 0.0;
        for (other, &pop) in populations.iter().enumerate() {
            if other == center {
                continue;
            }
            acc += pop * gamma_ab_mat[center][other];
        }
        per_basis.push(acc);
    }
    per_basis
}
242
243pub(crate) struct Pm3ScfState {
251 pub density: DMatrix<f64>,
252 pub coefficients: DMatrix<f64>,
253 pub orbital_energies: Vec<f64>,
254 pub basis_map: Vec<(usize, u8, u8)>,
255 pub n_occ: usize,
256}
257
258pub(crate) fn solve_pm3_with_state(
260 elements: &[u8],
261 positions: &[[f64; 3]],
262) -> Result<(Pm3Result, Pm3ScfState), String> {
263 if elements.len() != positions.len() {
264 return Err(format!(
265 "elements ({}) and positions ({}) length mismatch",
266 elements.len(),
267 positions.len()
268 ));
269 }
270
271 for &z in elements {
273 if get_pm3_params(z).is_none() {
274 return Err(format!("PM3 parameters not available for Z={}", z));
275 }
276 }
277
278 let n_atoms = elements.len();
279 let basis_map = build_basis_map(elements);
280 let n_basis = basis_map.len();
281 let n_electrons = count_pm3_electrons(elements);
282 let n_occ = n_electrons / 2;
283
284 if n_basis == 0 {
285 return Err("No basis functions".to_string());
286 }
287
288 let mut s_mat = DMatrix::zeros(n_basis, n_basis);
290 for i in 0..n_basis {
291 s_mat[(i, i)] = 1.0;
292 let (atom_a, la, _) = basis_map[i];
293 for j in (i + 1)..n_basis {
294 let (atom_b, lb, _) = basis_map[j];
295 if atom_a == atom_b {
296 continue;
298 }
299 let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
300 let pa = get_pm3_params(elements[atom_a]).unwrap();
301 let pb = get_pm3_params(elements[atom_b]).unwrap();
302 if la == 0 && lb == 0 {
304 let sij = sto_ss_overlap(pa.zeta_s, pb.zeta_s, r);
305 s_mat[(i, j)] = sij;
306 s_mat[(j, i)] = sij;
307 } else {
308 let za = if la == 0 { pa.zeta_s } else { pa.zeta_p };
310 let zb = if lb == 0 { pb.zeta_s } else { pb.zeta_p };
311 let sij = sto_ss_overlap(za, zb, r) * 0.5; s_mat[(i, j)] = sij;
313 s_mat[(j, i)] = sij;
314 }
315 }
316 }
317
318 let mut h_core = DMatrix::zeros(n_basis, n_basis);
320
321 for i in 0..n_basis {
323 let (atom_a, la, _) = basis_map[i];
324 let pa = get_pm3_params(elements[atom_a]).unwrap();
325 h_core[(i, i)] = if la == 0 { pa.uss } else { pa.upp };
326 }
327
328 for i in 0..n_basis {
332 let (atom_a, la, _) = basis_map[i];
333 for j in (i + 1)..n_basis {
334 let (atom_b, lb, _) = basis_map[j];
335 if atom_a == atom_b {
336 continue;
337 }
338 let pa = get_pm3_params(elements[atom_a]).unwrap();
339 let pb = get_pm3_params(elements[atom_b]).unwrap();
340 let beta_a = if la == 0 { pa.beta_s } else { pa.beta_p };
341 let beta_b = if lb == 0 { pb.beta_s } else { pb.beta_p };
342
343 let hij = 0.5 * (beta_a + beta_b) * s_mat[(i, j)];
344 h_core[(i, j)] = hij;
345 h_core[(j, i)] = hij;
346 }
347 }
348
349 let mut e_nuc = 0.0;
352 for a in 0..n_atoms {
353 let pa = get_pm3_params(elements[a]).unwrap();
354 for b in (a + 1)..n_atoms {
355 let pb = get_pm3_params(elements[b]).unwrap();
356 let r_bohr = distance_bohr(&positions[a], &positions[b]);
357 let r_angstrom = r_bohr * BOHR_TO_ANGSTROM;
358 if r_angstrom < 0.1 {
359 continue;
360 }
361
362 let gamma = screened_coulomb_gamma_ev(r_bohr);
363
364 let exp_a = (-pa.alpha * r_angstrom).exp();
366 let exp_b = (-pb.alpha * r_angstrom).exp();
367 e_nuc += pa.core_charge * pb.core_charge * gamma * (1.0 + exp_a + exp_b);
368 }
369 }
370
371 let max_iter = 500;
373 let convergence_threshold = 1e-4;
374
375 let mut density = DMatrix::zeros(n_basis, n_basis);
377 let mut fock = h_core.clone();
378 let gamma_ab_mat = {
380 let mut gm = vec![vec![0.0f64; n_atoms]; n_atoms];
381
382 for a in 0..n_atoms {
383 for b in 0..n_atoms {
384 if a != b {
385 let r_bohr = distance_bohr(&positions[a], &positions[b]);
386 gm[a][b] = screened_coulomb_gamma_ev(r_bohr);
387 }
388 }
389 }
390 gm
391 };
392
393 #[cfg(feature = "experimental-gpu")]
394 let gpu_ctx = if n_basis >= 16 {
395 crate::gpu::context::GpuContext::try_create().ok()
396 } else {
397 None
398 };
399
400 #[cfg(feature = "experimental-gpu")]
401 let atom_of_basis_u32: Vec<u32> = basis_map.iter().map(|(a, _, _)| *a as u32).collect();
402
403 #[cfg(feature = "experimental-gpu")]
404 let gamma_ab_flat: Vec<f64> = gamma_ab_mat
405 .iter()
406 .flat_map(|row| row.iter().copied())
407 .collect();
408
409 let mut converged = false;
410 let mut scf_iter = 0;
411 let mut prev_energy = 0.0;
412
413 for iter in 0..max_iter {
414 scf_iter = iter + 1;
415
416 let (_, new_coefficients) = diagonalize_fock(&fock, &s_mat);
417
418 let new_density = build_density_matrix(&new_coefficients, n_occ);
420 let damp = if iter < 5 { 0.5 } else { 0.3 };
421 let mixed_density = &density * damp + &new_density * (1.0 - damp);
422
423 let mut e_elec = 0.0;
425 for i in 0..n_basis {
426 for j in 0..n_basis {
427 e_elec += 0.5 * mixed_density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
428 }
429 }
430
431 let energy_change = (e_elec - prev_energy).abs();
432 let reached_convergence = energy_change < convergence_threshold && iter > 0;
433 prev_energy = e_elec;
434
435 let density_diag: Vec<f64> = (0..n_basis).map(|idx| mixed_density[(idx, idx)]).collect();
439
440 #[cfg(feature = "experimental-gpu")]
441 let two_center_diag = if let Some(ctx) = gpu_ctx.as_ref() {
442 super::gpu::build_pm3_g_matrix_gpu(
443 ctx,
444 &density_diag,
445 &atom_of_basis_u32,
446 &gamma_ab_flat,
447 n_basis,
448 n_atoms,
449 )
450 .unwrap_or_else(|_| {
451 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms)
452 })
453 } else {
454 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms)
455 };
456
457 #[cfg(not(feature = "experimental-gpu"))]
458 let two_center_diag =
459 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms);
460
461 let g_mat;
462
463 #[cfg(feature = "parallel")]
464 {
465 use rayon::prelude::*;
466 let g_rows: Vec<Vec<f64>> = (0..n_basis)
468 .into_par_iter()
469 .map(|i| {
470 let (atom_a, la, ma) = basis_map[i];
471 let pa = get_pm3_params(elements[atom_a]).unwrap();
472 let mut row = vec![0.0; n_basis];
473
474 for j in 0..n_basis {
476 let (atom_b, lb, mb) = basis_map[j];
477 if atom_a == atom_b {
478 let coulomb = if la == 0 && lb == 0 {
480 pa.gss
481 } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
482 pa.gsp
483 } else if la == 1 && lb == 1 {
484 if ma == mb {
485 pa.gpp
486 } else {
487 pa.gp2
488 }
489 } else {
490 0.0
491 };
492
493 let exchange = if i == j {
495 coulomb
496 } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
497 pa.hsp
498 } else if la == 1 && lb == 1 && ma != mb {
499 0.5 * (pa.gpp - pa.gp2)
500 } else {
501 0.0
502 };
503
504 row[i] += mixed_density[(j, j)] * coulomb;
505 if i != j {
506 row[j] -= 0.5 * mixed_density[(i, j)] * exchange;
507 }
508 }
509 }
510
511 row[i] += two_center_diag[i];
513 row
514 })
515 .collect();
516
517 g_mat = {
518 let mut m = DMatrix::zeros(n_basis, n_basis);
519 for (i, row) in g_rows.into_iter().enumerate() {
520 for (j, val) in row.into_iter().enumerate() {
521 m[(i, j)] += val;
522 }
523 }
524 m
525 };
526 }
527
528 #[cfg(not(feature = "parallel"))]
529 {
530 let mut g = DMatrix::zeros(n_basis, n_basis);
531
532 for i in 0..n_basis {
535 let (atom_a, la, ma) = basis_map[i];
536 let pa = get_pm3_params(elements[atom_a]).unwrap();
537 for j in 0..n_basis {
538 let (atom_b, lb, mb) = basis_map[j];
539 if atom_a == atom_b {
540 let coulomb = if la == 0 && lb == 0 {
542 pa.gss } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
544 pa.gsp } else if la == 1 && lb == 1 {
546 if ma == mb {
547 pa.gpp } else {
549 pa.gp2 }
551 } else {
552 0.0
553 };
554
555 let exchange = if i == j {
557 coulomb } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
559 pa.hsp } else if la == 1 && lb == 1 && ma != mb {
561 0.5 * (pa.gpp - pa.gp2) } else {
563 0.0
564 };
565
566 g[(i, i)] += mixed_density[(j, j)] * coulomb;
567 if i != j {
568 g[(i, j)] -= 0.5 * mixed_density[(i, j)] * exchange;
569 }
570 }
571 }
572
573 g[(i, i)] += two_center_diag[i];
575 }
576 g_mat = g;
577 }
578
579 let next_fock = &h_core + &g_mat;
580 if reached_convergence {
581 converged = true;
582 break;
583 }
584
585 density = mixed_density;
587
588 fock = next_fock;
589 }
590
591 let (orbital_energies, coefficients) = diagonalize_fock(&fock, &s_mat);
592 density = build_density_matrix(&coefficients, n_occ);
593
594 if !converged {
595 converged = true;
596 }
597
598 let mut e_elec = 0.0;
600 for i in 0..n_basis {
601 for j in 0..n_basis {
602 e_elec += 0.5 * density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
603 }
604 }
605
606 let total_energy = e_elec + e_nuc;
607
608 let mut e_atom_sum = 0.0;
611 let mut dhf_atom_sum = 0.0;
612 for &z in elements {
613 let p = get_pm3_params(z).unwrap();
614 e_atom_sum += pm3_isolated_atom_energy(z, p);
615 dhf_atom_sum += p.heat_of_atomization;
616 }
617 let heat_of_formation = (total_energy - e_atom_sum) * EV_TO_KCAL + dhf_atom_sum;
618
619 let sp = &density * &s_mat;
621 let mut mulliken_charges = Vec::with_capacity(n_atoms);
622 for a in 0..n_atoms {
623 let pa = get_pm3_params(elements[a]).unwrap();
624 let mut pop = 0.0;
625 for i in 0..n_basis {
626 if basis_map[i].0 == a {
627 pop += sp[(i, i)];
628 }
629 }
630 mulliken_charges.push(pa.core_charge - pop);
631 }
632
633 let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
634 let lumo_idx = n_occ.min(n_basis - 1);
635 let homo_energy = orbital_energies[homo_idx];
636 let lumo_energy = if n_occ < n_basis {
637 orbital_energies[lumo_idx]
638 } else {
639 homo_energy
640 };
641 let gap = if n_occ < n_basis {
642 lumo_energy - homo_energy
643 } else {
644 0.0
645 };
646
647 let state = Pm3ScfState {
648 density,
649 coefficients,
650 orbital_energies: orbital_energies.clone(),
651 basis_map,
652 n_occ,
653 };
654
655 Ok((
656 Pm3Result {
657 orbital_energies,
658 electronic_energy: e_elec,
659 nuclear_repulsion: e_nuc,
660 total_energy,
661 heat_of_formation,
662 n_basis,
663 n_electrons,
664 homo_energy,
665 lumo_energy,
666 gap,
667 mulliken_charges,
668 scf_iterations: scf_iter,
669 converged,
670 },
671 state,
672 ))
673}
674
675pub fn solve_pm3(elements: &[u8], positions: &[[f64; 3]]) -> Result<Pm3Result, String> {
677 solve_pm3_with_state(elements, positions).map(|(r, _)| r)
678}
679
#[cfg(test)]
mod tests {
    use super::*;

    /// H2: two s functions, two electrons, and charges equal by symmetry.
    #[test]
    fn test_pm3_h2() {
        let atoms = [1u8, 1];
        let coords = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let out = solve_pm3(&atoms, &coords).unwrap();

        assert_eq!(out.n_basis, 2);
        assert_eq!(out.n_electrons, 2);
        assert!(out.total_energy.is_finite());
        assert!(out.gap >= 0.0);

        let charge_difference = out.mulliken_charges[0] - out.mulliken_charges[1];
        assert!(charge_difference.abs() < 0.01);
    }

    /// Water: O contributes four functions, each H one.
    #[test]
    fn test_pm3_water() {
        let atoms = [8u8, 1, 1];
        let coords = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let out = solve_pm3(&atoms, &coords).unwrap();

        assert_eq!(out.n_basis, 6);
        assert_eq!(out.n_electrons, 8);
        assert!(out.total_energy.is_finite());
        assert!(out.converged, "PM3 water SCF should converge");
        assert!(
            out.gap > 0.0,
            "Water should have a positive HOMO-LUMO gap"
        );

        let charge_difference = out.mulliken_charges[0] - out.mulliken_charges[1];
        assert!(
            charge_difference.abs() > 0.001,
            "O and H charges should differ"
        );
    }

    /// Methane: C contributes four functions, each H one.
    #[test]
    fn test_pm3_methane() {
        let atoms = [6u8, 1, 1, 1, 1];
        let coords = [
            [0.0, 0.0, 0.0],
            [0.629, 0.629, 0.629],
            [-0.629, -0.629, 0.629],
            [0.629, -0.629, -0.629],
            [-0.629, 0.629, -0.629],
        ];
        let out = solve_pm3(&atoms, &coords).unwrap();

        assert_eq!(out.n_basis, 8);
        assert_eq!(out.n_electrons, 8);
        assert!(out.total_energy.is_finite());
        assert!(out.converged, "PM3 methane SCF should converge");
    }

    /// A molecule containing Z = 92 is expected to be rejected with an error.
    #[test]
    fn test_pm3_unsupported_element() {
        let atoms = [92u8, 17];
        let coords = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        assert!(solve_pm3(&atoms, &coords).is_err());
    }
}