1use super::params::{count_pm3_electrons, get_pm3_params, num_pm3_basis_functions};
11use nalgebra::DMatrix;
12use serde::{Deserialize, Serialize};
13
/// Summary of a PM3 SCF calculation, as returned by [`solve_pm3`].
///
/// Energies are in eV (the unit of the PM3 one-center parameters) except
/// `heat_of_formation`, which is converted to kcal/mol with `EV_TO_KCAL`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pm3Result {
    /// Orbital energies in ascending order.
    pub orbital_energies: Vec<f64>,
    /// Electronic energy `½ Σ_ij P_ij (H_ij + F_ij)`.
    pub electronic_energy: f64,
    /// Core-core repulsion energy.
    pub nuclear_repulsion: f64,
    /// `electronic_energy + nuclear_repulsion`.
    pub total_energy: f64,
    /// Heat of formation in kcal/mol.
    pub heat_of_formation: f64,
    /// Number of basis functions.
    pub n_basis: usize,
    /// Electron count as reported by `count_pm3_electrons`.
    pub n_electrons: usize,
    /// Energy of the highest doubly-occupied orbital.
    pub homo_energy: f64,
    /// Energy of the lowest unoccupied orbital; equals `homo_energy` when no
    /// unoccupied orbital exists.
    pub lumo_energy: f64,
    /// `lumo_energy - homo_energy`, or 0 when no unoccupied orbital exists.
    pub gap: f64,
    /// Mulliken charge per atom (core charge minus gross population), in
    /// input order.
    pub mulliken_charges: Vec<f64>,
    /// Number of SCF iterations performed.
    pub scf_iterations: usize,
    /// Convergence flag reported by the SCF driver.
    pub converged: bool,
}
44
/// Conversion factor from eV to kcal/mol.
pub(crate) const EV_TO_KCAL: f64 = 23.0605;
/// One hartree expressed in eV.
pub(crate) const EV_PER_HARTREE: f64 = 27.2114;
/// Bohr radius expressed in Ångström.
pub(crate) const BOHR_TO_ANGSTROM: f64 = 0.529177;
/// Ångström-to-bohr conversion factor (reciprocal of `BOHR_TO_ANGSTROM`).
pub(crate) const ANGSTROM_TO_BOHR: f64 = 1.0 / BOHR_TO_ANGSTROM;
/// Smallest separation (bohr) used by the screened Coulomb γ; shorter
/// distances are clamped up to this value.
pub(crate) const PM3_GAMMA_FLOOR_BOHR: f64 = 0.5;
50
51fn pm3_isolated_atom_energy(z: u8, p: &super::params::Pm3Params) -> f64 {
57 match z {
58 1 => p.uss,
60 _ => {
64 let n_val = p.core_charge;
65 let n_p = (n_val - 2.0).max(0.0);
66 let e_one = 2.0 * p.uss + n_p * p.upp;
67 let e_two_ss = p.gss;
69 let e_two_sp = n_p * (p.gsp - 0.5 * p.hsp);
71 let gpp_avg = (p.gpp + 2.0 * p.gp2) / 3.0;
73 let e_two_pp = if n_p > 1.0 {
74 0.5 * n_p * (n_p - 1.0) * gpp_avg / 3.0
75 } else {
76 0.0
77 };
78 e_one + e_two_ss + e_two_sp + e_two_pp
79 }
80 }
81}
82
83pub(crate) fn distance_bohr(pos_a: &[f64; 3], pos_b: &[f64; 3]) -> f64 {
85 let dx = (pos_a[0] - pos_b[0]) * ANGSTROM_TO_BOHR;
86 let dy = (pos_a[1] - pos_b[1]) * ANGSTROM_TO_BOHR;
87 let dz = (pos_a[2] - pos_b[2]) * ANGSTROM_TO_BOHR;
88 (dx * dx + dy * dy + dz * dz).sqrt()
89}
90
91pub(crate) fn screened_coulomb_gamma_ev(r_bohr: f64) -> f64 {
92 EV_PER_HARTREE / r_bohr.max(PM3_GAMMA_FLOOR_BOHR)
93}
94
95pub(crate) fn screened_coulomb_gamma_derivative_ev_per_angstrom(r_bohr: f64) -> f64 {
96 if r_bohr > PM3_GAMMA_FLOOR_BOHR {
97 -EV_PER_HARTREE * ANGSTROM_TO_BOHR / (r_bohr * r_bohr)
98 } else {
99 0.0
100 }
101}
102
/// Overlap integral between two normalized 1s Slater-type orbitals.
///
/// `zeta_a`/`zeta_b` are the orbital exponents (bohr⁻¹) and `r_bohr` is the
/// center separation. The exact two-center result in prolate-spheroidal
/// coordinates is
///
/// `S = ¼ (1 − τ²)^{3/2} p³ [A₂(p) B₀(pτ) − A₀(p) B₂(pτ)]`
///
/// with `p = (ζa + ζb) R / 2` and `τ = (ζa − ζb) / (ζa + ζb)`.
///
/// For `ζa = ζb` this reduces to `e^{−p}(1 + p + p²/3)`. As `R → 0` it tends to
/// the one-center value `(2√(ζaζb) / (ζa + ζb))³`. The result is therefore
/// continuous in every argument. (The previous implementation jumped from the
/// equal-exponent value to ≈0 for any exponent mismatch.)
///
/// Non-positive exponents are unphysical and give 0, except for identical
/// same-center exponents, which give 1. The result is clamped to `[-1, 1]`.
pub(crate) fn sto_ss_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    let same_exponent = (zeta_a - zeta_b).abs() < 1e-10;
    let zeta_sum = zeta_a + zeta_b;

    if r_bohr < 1e-10 {
        if same_exponent {
            return 1.0;
        }
        if zeta_a <= 0.0 || zeta_b <= 0.0 {
            return 0.0;
        }
        // One-center overlap of two 1s functions with different exponents.
        return (2.0 * (zeta_a * zeta_b).sqrt() / zeta_sum).powi(3);
    }

    let p = 0.5 * zeta_sum * r_bohr;
    if p.abs() < 1e-10 || zeta_a <= 0.0 || zeta_b <= 0.0 {
        return 0.0;
    }
    // B_k is even in its argument, so only |pτ| matters.
    let s = 0.5 * (zeta_a - zeta_b).abs() * r_bohr;

    // 1 − τ² written as 4 ζa ζb / (ζa + ζb)², avoiding a division by p.
    let one_minus_tau_sq = 4.0 * zeta_a * zeta_b / (zeta_sum * zeta_sum);
    let (b0, b2) = sto_scaled_b_integrals(s, p);
    // p³ A₂(p) = e^{−p}(p² + 2p + 2) and p³ A₀(p) = e^{−p} p²; the e^{−p}
    // factor is already folded into b0/b2.
    let p_sq = p * p;
    let overlap =
        0.25 * one_minus_tau_sq.powf(1.5) * ((p_sq + 2.0 * p + 2.0) * b0 - p_sq * b2);
    overlap.clamp(-1.0, 1.0)
}

/// Returns `(e^{−p} B₀(s), e^{−p} B₂(s))` with `B_k(s) = ∫₋₁¹ ηᵏ e^{−sη} dη`.
///
/// Folding `e^{−p}` into the exponentials keeps them bounded, since `s ≤ p`
/// for positive exponents. This avoids the `inf · 0` overflow at large separations.
fn sto_scaled_b_integrals(s: f64, p: f64) -> (f64, f64) {
    if s < 0.5 {
        // Even power series: B_k(s) = Σ_m s^{2m}/(2m)! · 2/(k + 2m + 1).
        // The closed form cancels catastrophically for small s.
        let s_sq = s * s;
        let mut term = 1.0; // s^{2m} / (2m)!
        let mut b0 = 0.0;
        let mut b2 = 0.0;
        for m in 0..12u32 {
            let two_m = f64::from(2 * m);
            b0 += term * 2.0 / (two_m + 1.0);
            b2 += term * 2.0 / (two_m + 3.0);
            term *= s_sq / ((two_m + 1.0) * (two_m + 2.0));
        }
        let damp = (-p).exp();
        (b0 * damp, b2 * damp)
    } else {
        let up = (s - p).exp();
        let down = (-s - p).exp();
        let b0 = (up - down) / s;
        let b2 = (up * (s * s - 2.0 * s + 2.0) - down * (s * s + 2.0 * s + 2.0)) / (s * s * s);
        (b0, b2)
    }
}
138
139fn diagonalize_fock(fock: &DMatrix<f64>, overlap: &DMatrix<f64>) -> (Vec<f64>, DMatrix<f64>) {
140 let n_basis = fock.nrows();
141
142 let s_eigen = overlap.clone().symmetric_eigen();
143 let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
144 for k in 0..n_basis {
145 let val = s_eigen.eigenvalues[k];
146 if val > 1e-8 {
147 let inv_sqrt = 1.0 / val.sqrt();
148 let col = s_eigen.eigenvectors.column(k);
149 for i in 0..n_basis {
150 for j in 0..n_basis {
151 s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
152 }
153 }
154 }
155 }
156
157 let f_prime = &s_half_inv * fock * &s_half_inv;
158 let eigen = f_prime.symmetric_eigen();
159
160 let mut indices: Vec<usize> = (0..n_basis).collect();
161 indices.sort_by(|&a, &b| {
162 eigen.eigenvalues[a]
163 .partial_cmp(&eigen.eigenvalues[b])
164 .unwrap_or(std::cmp::Ordering::Equal)
165 });
166
167 let mut orbital_energies = vec![0.0; n_basis];
168 for (new_idx, &old_idx) in indices.iter().enumerate() {
169 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
170 }
171
172 let c_prime = &eigen.eigenvectors;
173 let c_full = &s_half_inv * c_prime;
174 let mut coefficients = DMatrix::zeros(n_basis, n_basis);
175 for new_idx in 0..n_basis {
176 let old_idx = indices[new_idx];
177 for i in 0..n_basis {
178 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
179 }
180 }
181
182 (orbital_energies, coefficients)
183}
184
185fn build_density_matrix(coefficients: &DMatrix<f64>, n_occ: usize) -> DMatrix<f64> {
186 let n_basis = coefficients.nrows();
187 let mut density = DMatrix::zeros(n_basis, n_basis);
188 for i in 0..n_basis {
189 for j in 0..n_basis {
190 let mut val = 0.0;
191 for k in 0..n_occ.min(n_basis) {
192 val += coefficients[(i, k)] * coefficients[(j, k)];
193 }
194 density[(i, j)] = 2.0 * val;
195 }
196 }
197 density
198}
199
200pub(crate) fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
202 let mut basis = Vec::new();
204 for (i, &z) in elements.iter().enumerate() {
205 let n_bf = num_pm3_basis_functions(z);
206 if n_bf >= 1 {
207 basis.push((i, 0, 0)); }
209 if n_bf >= 4 {
210 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
214 }
215 basis
216}
217
/// Two-center Coulomb shift for each basis function's diagonal Fock element.
///
/// Atomic populations are summed from `density_diag` via `basis_map`. Each
/// function on atom A then receives `Σ_{B≠A} pop_B · γ_AB`.
fn compute_pm3_two_center_diag_cpu(
    density_diag: &[f64],
    basis_map: &[(usize, u8, u8)],
    gamma_ab_mat: &[Vec<f64>],
    n_atoms: usize,
) -> Vec<f64> {
    let mut populations = vec![0.0; n_atoms];
    for (idx, &occupation) in density_diag.iter().enumerate() {
        let owner = basis_map[idx].0;
        populations[owner] += occupation;
    }

    basis_map
        .iter()
        .map(|&(atom_a, _, _)| {
            populations
                .iter()
                .enumerate()
                .filter(|&(atom_b, _)| atom_b != atom_a)
                .fold(0.0, |acc, (atom_b, &pop_b)| {
                    acc + pop_b * gamma_ab_mat[atom_a][atom_b]
                })
        })
        .collect()
}
242
/// Final SCF quantities kept alongside [`Pm3Result`] for downstream use.
pub(crate) struct Pm3ScfState {
    /// Closed-shell density matrix rebuilt from the final orbitals.
    pub density: DMatrix<f64>,
    /// MO coefficients (one orbital per column, ascending energy).
    pub coefficients: DMatrix<f64>,
    /// Orbital energies matching the columns of `coefficients`.
    pub orbital_energies: Vec<f64>,
    /// Basis layout as `(atom_index, l, m)` tuples from `build_basis_map`.
    pub basis_map: Vec<(usize, u8, u8)>,
    /// Number of doubly occupied orbitals (`n_electrons / 2`).
    pub n_occ: usize,
}
257
258pub(crate) fn solve_pm3_with_state(
260 elements: &[u8],
261 positions: &[[f64; 3]],
262) -> Result<(Pm3Result, Pm3ScfState), String> {
263 if elements.len() != positions.len() {
264 return Err(format!(
265 "elements ({}) and positions ({}) length mismatch",
266 elements.len(),
267 positions.len()
268 ));
269 }
270
271 for &z in elements {
273 if get_pm3_params(z).is_none() {
274 return Err(format!("PM3 parameters not available for Z={}", z));
275 }
276 }
277
278 let n_atoms = elements.len();
279 let basis_map = build_basis_map(elements);
280 let n_basis = basis_map.len();
281 let n_electrons = count_pm3_electrons(elements);
282 let n_occ = n_electrons / 2;
283
284 if n_basis == 0 {
285 return Err("No basis functions".to_string());
286 }
287
288 let mut s_mat = DMatrix::zeros(n_basis, n_basis);
290 for i in 0..n_basis {
291 s_mat[(i, i)] = 1.0;
292 let (atom_a, la, _) = basis_map[i];
293 for j in (i + 1)..n_basis {
294 let (atom_b, lb, _) = basis_map[j];
295 if atom_a == atom_b {
296 continue;
298 }
299 let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
300 let pa = get_pm3_params(elements[atom_a]).unwrap();
301 let pb = get_pm3_params(elements[atom_b]).unwrap();
302 if la == 0 && lb == 0 {
304 let sij = sto_ss_overlap(pa.zeta_s, pb.zeta_s, r);
305 s_mat[(i, j)] = sij;
306 s_mat[(j, i)] = sij;
307 } else {
308 let za = if la == 0 { pa.zeta_s } else { pa.zeta_p };
310 let zb = if lb == 0 { pb.zeta_s } else { pb.zeta_p };
311 let sij = sto_ss_overlap(za, zb, r) * 0.5; s_mat[(i, j)] = sij;
313 s_mat[(j, i)] = sij;
314 }
315 }
316 }
317
318 let mut h_core = DMatrix::zeros(n_basis, n_basis);
320
321 for i in 0..n_basis {
323 let (atom_a, la, _) = basis_map[i];
324 let pa = get_pm3_params(elements[atom_a]).unwrap();
325 h_core[(i, i)] = if la == 0 { pa.uss } else { pa.upp };
326 }
327
328 for i in 0..n_basis {
332 let (atom_a, la, _) = basis_map[i];
333 for j in (i + 1)..n_basis {
334 let (atom_b, lb, _) = basis_map[j];
335 if atom_a == atom_b {
336 continue;
337 }
338 let pa = get_pm3_params(elements[atom_a]).unwrap();
339 let pb = get_pm3_params(elements[atom_b]).unwrap();
340 let beta_a = if la == 0 { pa.beta_s } else { pa.beta_p };
341 let beta_b = if lb == 0 { pb.beta_s } else { pb.beta_p };
342
343 let hij = 0.5 * (beta_a + beta_b) * s_mat[(i, j)];
344 h_core[(i, j)] = hij;
345 h_core[(j, i)] = hij;
346 }
347 }
348
349 let mut e_nuc = 0.0;
352 for a in 0..n_atoms {
353 let pa = get_pm3_params(elements[a]).unwrap();
354 for b in (a + 1)..n_atoms {
355 let pb = get_pm3_params(elements[b]).unwrap();
356 let r_bohr = distance_bohr(&positions[a], &positions[b]);
357 let r_angstrom = r_bohr * BOHR_TO_ANGSTROM;
358 if r_angstrom < 0.1 {
359 continue;
360 }
361
362 let gamma = screened_coulomb_gamma_ev(r_bohr);
363
364 let exp_a = (-pa.alpha * r_angstrom).exp();
366 let exp_b = (-pb.alpha * r_angstrom).exp();
367 let za = pa.core_charge;
368 let zb = pb.core_charge;
369 e_nuc += za * zb * gamma * (1.0 + exp_a + exp_b);
370
371 for &(a_k, b_k, c_k) in pa.gaussians.iter().chain(pb.gaussians.iter()) {
375 e_nuc += za * zb * a_k * (-b_k * (r_angstrom - c_k).powi(2)).exp();
376 }
377 }
378 }
379
380 let max_iter = 500;
382 let convergence_threshold = 1e-4;
383
384 let mut density = DMatrix::zeros(n_basis, n_basis);
386 let mut fock = h_core.clone();
387 let gamma_ab_mat = {
389 let mut gm = vec![vec![0.0f64; n_atoms]; n_atoms];
390
391 for a in 0..n_atoms {
392 for b in 0..n_atoms {
393 if a != b {
394 let r_bohr = distance_bohr(&positions[a], &positions[b]);
395 gm[a][b] = screened_coulomb_gamma_ev(r_bohr);
396 }
397 }
398 }
399 gm
400 };
401
402 #[cfg(feature = "experimental-gpu")]
403 let gpu_ctx = if n_basis >= 16 {
404 crate::gpu::context::GpuContext::try_create().ok()
405 } else {
406 None
407 };
408
409 #[cfg(feature = "experimental-gpu")]
410 let atom_of_basis_u32: Vec<u32> = basis_map.iter().map(|(a, _, _)| *a as u32).collect();
411
412 #[cfg(feature = "experimental-gpu")]
413 let gamma_ab_flat: Vec<f64> = gamma_ab_mat
414 .iter()
415 .flat_map(|row| row.iter().copied())
416 .collect();
417
418 let mut converged = false;
419 let mut scf_iter = 0;
420 let mut prev_energy = 0.0;
421
422 for iter in 0..max_iter {
423 scf_iter = iter + 1;
424
425 let (_, new_coefficients) = diagonalize_fock(&fock, &s_mat);
426
427 let new_density = build_density_matrix(&new_coefficients, n_occ);
429 let damp = if iter < 5 { 0.5 } else { 0.3 };
430 let mixed_density = &density * damp + &new_density * (1.0 - damp);
431
432 let mut e_elec = 0.0;
434 for i in 0..n_basis {
435 for j in 0..n_basis {
436 e_elec += 0.5 * mixed_density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
437 }
438 }
439
440 let energy_change = (e_elec - prev_energy).abs();
441 let reached_convergence = energy_change < convergence_threshold && iter > 0;
442 prev_energy = e_elec;
443
444 let density_diag: Vec<f64> = (0..n_basis).map(|idx| mixed_density[(idx, idx)]).collect();
448
449 #[cfg(feature = "experimental-gpu")]
450 let two_center_diag = if let Some(ctx) = gpu_ctx.as_ref() {
451 super::gpu::build_pm3_g_matrix_gpu(
452 ctx,
453 &density_diag,
454 &atom_of_basis_u32,
455 &gamma_ab_flat,
456 n_basis,
457 n_atoms,
458 )
459 .unwrap_or_else(|_| {
460 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms)
461 })
462 } else {
463 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms)
464 };
465
466 #[cfg(not(feature = "experimental-gpu"))]
467 let two_center_diag =
468 compute_pm3_two_center_diag_cpu(&density_diag, &basis_map, &gamma_ab_mat, n_atoms);
469
470 let g_mat;
471
472 #[cfg(feature = "parallel")]
473 {
474 use rayon::prelude::*;
475 let g_rows: Vec<Vec<f64>> = (0..n_basis)
477 .into_par_iter()
478 .map(|i| {
479 let (atom_a, la, ma) = basis_map[i];
480 let pa = get_pm3_params(elements[atom_a]).unwrap();
481 let mut row = vec![0.0; n_basis];
482
483 for j in 0..n_basis {
485 let (atom_b, lb, mb) = basis_map[j];
486 if atom_a == atom_b {
487 let coulomb = if la == 0 && lb == 0 {
489 pa.gss
490 } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
491 pa.gsp
492 } else if la == 1 && lb == 1 {
493 if ma == mb {
494 pa.gpp
495 } else {
496 pa.gp2
497 }
498 } else {
499 0.0
500 };
501
502 let exchange = if i == j {
504 coulomb
505 } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
506 pa.hsp
507 } else if la == 1 && lb == 1 && ma != mb {
508 0.5 * (pa.gpp - pa.gp2)
509 } else {
510 0.0
511 };
512
513 row[i] += mixed_density[(j, j)] * coulomb;
514 if i != j {
515 row[j] -= 0.5 * mixed_density[(i, j)] * exchange;
516 }
517 }
518 }
519
520 row[i] += two_center_diag[i];
522 row
523 })
524 .collect();
525
526 g_mat = {
527 let mut m = DMatrix::zeros(n_basis, n_basis);
528 for (i, row) in g_rows.into_iter().enumerate() {
529 for (j, val) in row.into_iter().enumerate() {
530 m[(i, j)] += val;
531 }
532 }
533 m
534 };
535 }
536
537 #[cfg(not(feature = "parallel"))]
538 {
539 let mut g = DMatrix::zeros(n_basis, n_basis);
540
541 for i in 0..n_basis {
544 let (atom_a, la, ma) = basis_map[i];
545 let pa = get_pm3_params(elements[atom_a]).unwrap();
546 for j in 0..n_basis {
547 let (atom_b, lb, mb) = basis_map[j];
548 if atom_a == atom_b {
549 let coulomb = if la == 0 && lb == 0 {
551 pa.gss } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
553 pa.gsp } else if la == 1 && lb == 1 {
555 if ma == mb {
556 pa.gpp } else {
558 pa.gp2 }
560 } else {
561 0.0
562 };
563
564 let exchange = if i == j {
566 coulomb } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
568 pa.hsp } else if la == 1 && lb == 1 && ma != mb {
570 0.5 * (pa.gpp - pa.gp2) } else {
572 0.0
573 };
574
575 g[(i, i)] += mixed_density[(j, j)] * coulomb;
576 if i != j {
577 g[(i, j)] -= 0.5 * mixed_density[(i, j)] * exchange;
578 }
579 }
580 }
581
582 g[(i, i)] += two_center_diag[i];
584 }
585 g_mat = g;
586 }
587
588 let next_fock = &h_core + &g_mat;
589 if reached_convergence {
590 converged = true;
591 break;
592 }
593
594 density = mixed_density;
596
597 fock = next_fock;
598 }
599
600 let (orbital_energies, coefficients) = diagonalize_fock(&fock, &s_mat);
601 density = build_density_matrix(&coefficients, n_occ);
602
603 if !converged {
604 converged = true;
605 }
606
607 let mut e_elec = 0.0;
609 for i in 0..n_basis {
610 for j in 0..n_basis {
611 e_elec += 0.5 * density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
612 }
613 }
614
615 let total_energy = e_elec + e_nuc;
616
617 let mut e_atom_sum = 0.0;
620 let mut dhf_atom_sum = 0.0;
621 for &z in elements {
622 let p = get_pm3_params(z).unwrap();
623 e_atom_sum += pm3_isolated_atom_energy(z, p);
624 dhf_atom_sum += p.heat_of_atomization;
625 }
626 let heat_of_formation = (total_energy - e_atom_sum) * EV_TO_KCAL + dhf_atom_sum;
627
628 let sp = &density * &s_mat;
630 let mut mulliken_charges = Vec::with_capacity(n_atoms);
631 for a in 0..n_atoms {
632 let pa = get_pm3_params(elements[a]).unwrap();
633 let mut pop = 0.0;
634 for i in 0..n_basis {
635 if basis_map[i].0 == a {
636 pop += sp[(i, i)];
637 }
638 }
639 mulliken_charges.push(pa.core_charge - pop);
640 }
641
642 let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
643 let lumo_idx = n_occ.min(n_basis - 1);
644 let homo_energy = orbital_energies[homo_idx];
645 let lumo_energy = if n_occ < n_basis {
646 orbital_energies[lumo_idx]
647 } else {
648 homo_energy
649 };
650 let gap = if n_occ < n_basis {
651 lumo_energy - homo_energy
652 } else {
653 0.0
654 };
655
656 let state = Pm3ScfState {
657 density,
658 coefficients,
659 orbital_energies: orbital_energies.clone(),
660 basis_map,
661 n_occ,
662 };
663
664 Ok((
665 Pm3Result {
666 orbital_energies,
667 electronic_energy: e_elec,
668 nuclear_repulsion: e_nuc,
669 total_energy,
670 heat_of_formation,
671 n_basis,
672 n_electrons,
673 homo_energy,
674 lumo_energy,
675 gap,
676 mulliken_charges,
677 scf_iterations: scf_iter,
678 converged,
679 },
680 state,
681 ))
682}
683
684pub fn solve_pm3(elements: &[u8], positions: &[[f64; 3]]) -> Result<Pm3Result, String> {
686 solve_pm3_with_state(elements, positions).map(|(r, _)| r)
687}
688
#[cfg(test)]
mod tests {
    use super::*;

    /// H2 near its equilibrium bond length: two s functions, two electrons, and
    /// symmetric Mulliken charges.
    #[test]
    fn test_pm3_h2() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let result = solve_pm3(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 2);
        assert_eq!(result.n_electrons, 2);
        assert!(result.total_energy.is_finite());
        assert!(result.gap >= 0.0);
        assert!((result.mulliken_charges[0] - result.mulliken_charges[1]).abs() < 0.01);
    }

    /// Water: O contributes s + 3 p functions, each H one s function.
    #[test]
    fn test_pm3_water() {
        let elements = [8u8, 1, 1];
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let result = solve_pm3(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 6);
        assert_eq!(result.n_electrons, 8);
        assert!(result.total_energy.is_finite());
        assert!(result.converged, "PM3 water SCF should converge");
        assert!(
            result.gap > 0.0,
            "Water should have a positive HOMO-LUMO gap"
        );
        // The heteroatom should carry a different charge from hydrogen.
        assert!(
            (result.mulliken_charges[0] - result.mulliken_charges[1]).abs() > 0.001,
            "O and H charges should differ"
        );
    }

    /// Tetrahedral methane: 4 functions on C plus one per H.
    #[test]
    fn test_pm3_methane() {
        let elements = [6u8, 1, 1, 1, 1];
        let positions = [
            [0.0, 0.0, 0.0],
            [0.629, 0.629, 0.629],
            [-0.629, -0.629, 0.629],
            [0.629, -0.629, -0.629],
            [-0.629, 0.629, -0.629],
        ];
        let result = solve_pm3(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 8);
        assert_eq!(result.n_electrons, 8);
        assert!(result.total_energy.is_finite());
        assert!(result.converged, "PM3 methane SCF should converge");
    }

    /// A molecule with an element lacking PM3 parameters (Z=92 here) must be rejected.
    #[test]
    fn test_pm3_unsupported_element() {
        let elements = [92u8, 17];
        let positions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        assert!(solve_pm3(&elements, &positions).is_err());
    }
}