1use super::params::{count_pm3_electrons, get_pm3_params, num_pm3_basis_functions};
11use nalgebra::DMatrix;
12use serde::{Deserialize, Serialize};
13
/// Results of a [`solve_pm3`] calculation.
///
/// Energies are in eV (the unit `solve_pm3` converts to kcal/mol with
/// `EV_TO_KCAL`), except `heat_of_formation`, which is in kcal/mol.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Pm3Result {
    /// Orbital energies (Fock-matrix eigenvalues), sorted in ascending order.
    pub orbital_energies: Vec<f64>,
    /// Electronic energy, ½ Σ_ij P_ij (H_ij + F_ij), from the final density and Fock matrices.
    pub electronic_energy: f64,
    /// Core–core repulsion energy summed over all atom pairs.
    pub nuclear_repulsion: f64,
    /// `electronic_energy + nuclear_repulsion`.
    pub total_energy: f64,
    /// Heat of formation in kcal/mol: (total energy − atomic reference energies)
    /// converted with `EV_TO_KCAL`, plus the summed atomic heats of atomization.
    pub heat_of_formation: f64,
    /// Number of basis functions (one s function per atom, plus three p functions
    /// for atoms that carry a p shell).
    pub n_basis: usize,
    /// Electron count reported by `count_pm3_electrons` (presumably valence electrons).
    pub n_electrons: usize,
    /// Energy of the highest doubly-occupied orbital.
    pub homo_energy: f64,
    /// Energy of the lowest unoccupied orbital; equals `homo_energy` when no
    /// virtual orbital exists.
    pub lumo_energy: f64,
    /// `lumo_energy - homo_energy`, or 0 when no virtual orbital exists.
    pub gap: f64,
    /// Mulliken partial charge per atom: core charge minus Σ_{μ on atom} (PS)_μμ.
    pub mulliken_charges: Vec<f64>,
    /// Number of SCF iterations performed.
    pub scf_iterations: usize,
    /// `true` if the change in electronic energy fell below the SCF threshold
    /// before the iteration limit was reached.
    pub converged: bool,
}
44
/// Length of one bohr expressed in ångström.
const BOHR_TO_ANGSTROM: f64 = 0.529177;
/// Multiply a length in ångström by this factor to obtain bohr.
const ANGSTROM_TO_BOHR: f64 = 1.0 / BOHR_TO_ANGSTROM;
/// Energy conversion: 1 eV ≈ 23.0605 kcal/mol.
const EV_TO_KCAL: f64 = 23.0605;
48
49fn distance_bohr(pos_a: &[f64; 3], pos_b: &[f64; 3]) -> f64 {
51 let dx = (pos_a[0] - pos_b[0]) * ANGSTROM_TO_BOHR;
52 let dy = (pos_a[1] - pos_b[1]) * ANGSTROM_TO_BOHR;
53 let dz = (pos_a[2] - pos_b[2]) * ANGSTROM_TO_BOHR;
54 (dx * dx + dy * dy + dz * dz).sqrt()
55}
56
/// Overlap integral between two normalized 1s Slater-type orbitals.
///
/// Uses the exact closed form in prolate-spheroidal coordinates:
///
/// `S = (√(ζa·ζb)·R)³ / 4 · [A₂(p)·B₀(q) − A₀(p)·B₂(q)]`
///
/// with `p = (ζa + ζb)·R/2`, `q = (ζa − ζb)·R/2`,
/// `A_n(p) = ∫₁^∞ xⁿ e^{−px} dx` and `B_n(q) = ∫₋₁¹ xⁿ e^{−qx} dx`.
/// For equal exponents this reduces to the textbook `e^{−p}(1 + p + p²/3)`.
/// The result is symmetric in the two exponents and continuous as they
/// approach each other.
///
/// # Arguments
/// * `zeta_a`, `zeta_b` – orbital exponents (bohr⁻¹). Non-positive exponents do
///   not describe normalizable orbitals and yield 0.
/// * `r_bohr` – distance between the two centres in bohr.
///
/// The returned value is clamped to `[-1, 1]` as a numerical safeguard.
fn sto_ss_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    if zeta_a <= 0.0 || zeta_b <= 0.0 {
        return 0.0;
    }

    let p = 0.5 * (zeta_a + zeta_b) * r_bohr;

    // Coincident centres (or vanishing p): the R → 0 limit is
    // (2√(ζa ζb) / (ζa + ζb))³, which equals 1 for equal exponents.
    if r_bohr < 1e-10 || p < 1e-10 {
        let ratio = 2.0 * (zeta_a * zeta_b).sqrt() / (zeta_a + zeta_b);
        return ratio.powi(3).clamp(-1.0, 1.0);
    }

    let q = 0.5 * (zeta_a - zeta_b) * r_bohr;

    // e0 = e^{-p}·B0(q), e2 = e^{-p}·B2(q). The e^{-p} factor of A_n is folded in
    // here so that large |q| never overflows (|q| < p for positive exponents).
    let (e0, e2) = if q.abs() <= 1.0 {
        // Power series; both B0 and B2 are even in q:
        //   B0 = Σ q^{2m}/(2m)! · 2/(2m+1),  B2 = Σ q^{2m}/(2m)! · 2/(2m+3).
        // The closed form below cancels catastrophically as q → 0.
        let q2 = q * q;
        let mut term = 1.0; // q^{2m} / (2m)!
        let mut b0 = 0.0;
        let mut b2 = 0.0;
        for m in 0..12 {
            let k = 2.0 * m as f64;
            b0 += term * 2.0 / (k + 1.0);
            b2 += term * 2.0 / (k + 3.0);
            term *= q2 / ((k + 1.0) * (k + 2.0));
        }
        let exp_neg_p = (-p).exp();
        (exp_neg_p * b0, exp_neg_p * b2)
    } else {
        // B0(q) = 2 sinh(q)/q,
        // B2(q) = e^{q}(1/q − 2/q² + 2/q³) − e^{−q}(1/q + 2/q² + 2/q³).
        let plus = (q - p).exp(); // e^{q-p}
        let minus = (-q - p).exp(); // e^{-q-p}
        let inv_q = 1.0 / q;
        let inv_q2 = inv_q * inv_q;
        let inv_q3 = inv_q2 * inv_q;
        let e0 = (plus - minus) * inv_q;
        let e2 = plus * (inv_q - 2.0 * inv_q2 + 2.0 * inv_q3)
            - minus * (inv_q + 2.0 * inv_q2 + 2.0 * inv_q3);
        (e0, e2)
    };

    // A_n(p) without its e^{-p} factor (already included in e0/e2).
    let a0 = 1.0 / p;
    let a2 = 1.0 / p + 2.0 / (p * p) + 2.0 / (p * p * p);

    // (p² − q²)^{3/2} = (√(ζa ζb)·R)³, computed without the p² − q² cancellation.
    let root = (zeta_a * zeta_b).sqrt() * r_bohr;
    let prefactor = root * root * root / 4.0;

    (prefactor * (a2 * e0 - a0 * e2)).clamp(-1.0, 1.0)
}
92
93fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
95 let mut basis = Vec::new();
97 for (i, &z) in elements.iter().enumerate() {
98 let n_bf = num_pm3_basis_functions(z);
99 if n_bf >= 1 {
100 basis.push((i, 0, 0)); }
102 if n_bf >= 4 {
103 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
107 }
108 basis
109}
110
111pub fn solve_pm3(elements: &[u8], positions: &[[f64; 3]]) -> Result<Pm3Result, String> {
118 if elements.len() != positions.len() {
119 return Err(format!(
120 "elements ({}) and positions ({}) length mismatch",
121 elements.len(),
122 positions.len()
123 ));
124 }
125
126 for &z in elements {
128 if get_pm3_params(z).is_none() {
129 return Err(format!("PM3 parameters not available for Z={}", z));
130 }
131 }
132
133 let n_atoms = elements.len();
134 let basis_map = build_basis_map(elements);
135 let n_basis = basis_map.len();
136 let n_electrons = count_pm3_electrons(elements);
137 let n_occ = n_electrons / 2;
138
139 if n_basis == 0 {
140 return Err("No basis functions".to_string());
141 }
142
143 let mut s_mat = DMatrix::zeros(n_basis, n_basis);
145 for i in 0..n_basis {
146 s_mat[(i, i)] = 1.0;
147 let (atom_a, la, _) = basis_map[i];
148 for j in (i + 1)..n_basis {
149 let (atom_b, lb, _) = basis_map[j];
150 if atom_a == atom_b {
151 continue;
153 }
154 let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
155 let pa = get_pm3_params(elements[atom_a]).unwrap();
156 let pb = get_pm3_params(elements[atom_b]).unwrap();
157 if la == 0 && lb == 0 {
159 let sij = sto_ss_overlap(pa.zeta_s, pb.zeta_s, r);
160 s_mat[(i, j)] = sij;
161 s_mat[(j, i)] = sij;
162 } else {
163 let za = if la == 0 { pa.zeta_s } else { pa.zeta_p };
165 let zb = if lb == 0 { pb.zeta_s } else { pb.zeta_p };
166 let sij = sto_ss_overlap(za, zb, r) * 0.5; s_mat[(i, j)] = sij;
168 s_mat[(j, i)] = sij;
169 }
170 }
171 }
172
173 let mut h_core = DMatrix::zeros(n_basis, n_basis);
175
176 for i in 0..n_basis {
178 let (atom_a, la, _) = basis_map[i];
179 let pa = get_pm3_params(elements[atom_a]).unwrap();
180 h_core[(i, i)] = if la == 0 { pa.uss } else { pa.upp };
181 }
182
183 for i in 0..n_basis {
185 let (atom_a, la, _) = basis_map[i];
186 for j in (i + 1)..n_basis {
187 let (atom_b, lb, _) = basis_map[j];
188 if atom_a == atom_b {
189 continue;
190 }
191 let pa = get_pm3_params(elements[atom_a]).unwrap();
192 let pb = get_pm3_params(elements[atom_b]).unwrap();
193 let beta_a = if la == 0 { pa.beta_s } else { pa.beta_p };
194 let beta_b = if lb == 0 { pb.beta_s } else { pb.beta_p };
195 let hij = 0.5 * (beta_a + beta_b) * s_mat[(i, j)];
196 h_core[(i, j)] = hij;
197 h_core[(j, i)] = hij;
198 }
199 }
200
201 let mut e_nuc = 0.0;
203 for a in 0..n_atoms {
204 let pa = get_pm3_params(elements[a]).unwrap();
205 for b in (a + 1)..n_atoms {
206 let pb = get_pm3_params(elements[b]).unwrap();
207 let r_bohr = distance_bohr(&positions[a], &positions[b]);
208 let r_angstrom = r_bohr * BOHR_TO_ANGSTROM;
209 if r_angstrom < 0.1 {
210 continue;
211 }
212
213 let _gamma_ss = 1.0 / r_bohr; let ev_per_hartree = 27.2114;
216 let gamma = ev_per_hartree / r_bohr.max(0.1);
217
218 let alpha_term = (-pa.alpha * r_angstrom).exp() + (-pb.alpha * r_angstrom).exp();
219 e_nuc += pa.core_charge * pb.core_charge * gamma * (1.0 + alpha_term);
220 }
221 }
222
223 let max_iter = 100;
225 let convergence_threshold = 1e-6;
226
227 let mut density = DMatrix::zeros(n_basis, n_basis);
229 let mut fock = h_core.clone();
230 let mut orbital_energies = vec![0.0; n_basis];
231 let mut coefficients = DMatrix::zeros(n_basis, n_basis);
232 let mut converged = false;
233 let mut scf_iter = 0;
234 let mut prev_energy = 0.0;
235
236 for iter in 0..max_iter {
237 scf_iter = iter + 1;
238
239 let s_eigen = s_mat.clone().symmetric_eigen();
242 let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
243 for k in 0..n_basis {
244 let val = s_eigen.eigenvalues[k];
245 if val > 1e-8 {
246 let inv_sqrt = 1.0 / val.sqrt();
247 let col = s_eigen.eigenvectors.column(k);
248 for i in 0..n_basis {
249 for j in 0..n_basis {
250 s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
251 }
252 }
253 }
254 }
255
256 let f_prime = &s_half_inv * &fock * &s_half_inv;
257 let eigen = f_prime.symmetric_eigen();
258
259 let mut indices: Vec<usize> = (0..n_basis).collect();
261 indices.sort_by(|&a, &b| {
262 eigen.eigenvalues[a]
263 .partial_cmp(&eigen.eigenvalues[b])
264 .unwrap_or(std::cmp::Ordering::Equal)
265 });
266
267 for (new_idx, &old_idx) in indices.iter().enumerate() {
268 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
269 }
270
271 let c_prime = &eigen.eigenvectors;
273 let c_full = &s_half_inv * c_prime;
274
275 for new_idx in 0..n_basis {
277 let old_idx = indices[new_idx];
278 for i in 0..n_basis {
279 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
280 }
281 }
282
283 let mut new_density = DMatrix::zeros(n_basis, n_basis);
285 for i in 0..n_basis {
286 for j in 0..n_basis {
287 let mut val = 0.0;
288 for k in 0..n_occ.min(n_basis) {
289 val += coefficients[(i, k)] * coefficients[(j, k)];
290 }
291 new_density[(i, j)] = 2.0 * val;
292 }
293 }
294
295 let mut e_elec = 0.0;
297 for i in 0..n_basis {
298 for j in 0..n_basis {
299 e_elec += 0.5 * new_density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
300 }
301 }
302
303 if (e_elec - prev_energy).abs() < convergence_threshold && iter > 0 {
305 converged = true;
306 density = new_density;
307 break;
308 }
309
310 prev_energy = e_elec;
311
312 let mut g_mat = DMatrix::zeros(n_basis, n_basis);
316
317 for i in 0..n_basis {
319 let (atom_a, la, _ma) = basis_map[i];
320 let pa = get_pm3_params(elements[atom_a]).unwrap();
321 for j in 0..n_basis {
322 let (atom_b, lb, _mb) = basis_map[j];
323 if atom_a == atom_b {
324 let gij = if la == 0 && lb == 0 {
326 pa.gss
327 } else if (la == 0 && lb == 1) || (la == 1 && lb == 0) {
328 pa.gsp
329 } else if la == 1 && lb == 1 {
330 pa.gpp
331 } else {
332 0.0
333 };
334
335 g_mat[(i, i)] += new_density[(j, j)] * gij;
337 if i != j {
338 g_mat[(i, j)] -= 0.5 * new_density[(i, j)] * gij;
340 }
341 }
342 }
343
344 for b in 0..n_atoms {
346 if b == atom_a {
347 continue;
348 }
349 let r_bohr = distance_bohr(&positions[atom_a], &positions[b]);
350 let ev_per_hartree = 27.2114;
351 let gamma_ab = ev_per_hartree / r_bohr.max(0.5);
352
353 let mut p_b = 0.0;
355 for k in 0..n_basis {
356 if basis_map[k].0 == b {
357 p_b += new_density[(k, k)];
358 }
359 }
360 let pb_params = get_pm3_params(elements[b]).unwrap();
361 g_mat[(i, i)] += (p_b - pb_params.core_charge) * gamma_ab * 0.0; g_mat[(i, i)] += p_b * gamma_ab * 0.1; }
366 }
367
368 let damp = if iter < 5 { 0.5 } else { 0.3 };
370 density = &density * damp + &new_density * (1.0 - damp);
371
372 fock = &h_core + &g_mat;
374 }
375
376 let mut e_elec = 0.0;
378 for i in 0..n_basis {
379 for j in 0..n_basis {
380 e_elec += 0.5 * density[(i, j)] * (h_core[(i, j)] + fock[(i, j)]);
381 }
382 }
383
384 let total_energy = e_elec + e_nuc;
385
386 let mut e_atom_sum = 0.0;
388 let mut dhf_atom_sum = 0.0;
389 for &z in elements {
390 let p = get_pm3_params(z).unwrap();
391 e_atom_sum += if z == 1 {
393 p.uss * p.core_charge * 0.5
394 } else {
395 (p.uss + 3.0 * p.upp) * 0.25 * p.core_charge
396 };
397 dhf_atom_sum += p.heat_of_atomization;
398 }
399 let heat_of_formation = (total_energy - e_atom_sum) * EV_TO_KCAL + dhf_atom_sum;
400
401 let sp = &density * &s_mat;
403 let mut mulliken_charges = Vec::with_capacity(n_atoms);
404 for a in 0..n_atoms {
405 let pa = get_pm3_params(elements[a]).unwrap();
406 let mut pop = 0.0;
407 for i in 0..n_basis {
408 if basis_map[i].0 == a {
409 pop += sp[(i, i)];
410 }
411 }
412 mulliken_charges.push(pa.core_charge - pop);
413 }
414
415 let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
416 let lumo_idx = n_occ.min(n_basis - 1);
417 let homo_energy = orbital_energies[homo_idx];
418 let lumo_energy = if n_occ < n_basis {
419 orbital_energies[lumo_idx]
420 } else {
421 homo_energy
422 };
423 let gap = if n_occ < n_basis {
424 lumo_energy - homo_energy
425 } else {
426 0.0
427 };
428
429 Ok(Pm3Result {
430 orbital_energies,
431 electronic_energy: e_elec,
432 nuclear_repulsion: e_nuc,
433 total_energy,
434 heat_of_formation,
435 n_basis,
436 n_electrons,
437 homo_energy,
438 lumo_energy,
439 gap,
440 mulliken_charges,
441 scf_iterations: scf_iter,
442 converged,
443 })
444}
445
#[cfg(test)]
mod tests {
    use super::*;

    /// Solves a molecule that is expected to be valid, panicking otherwise.
    fn run(elements: &[u8], positions: &[[f64; 3]]) -> Pm3Result {
        solve_pm3(elements, positions).unwrap()
    }

    #[test]
    fn test_pm3_h2() {
        let h2 = run(&[1, 1], &[[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]);
        assert_eq!(h2.n_basis, 2);
        assert_eq!(h2.n_electrons, 2);
        assert!(h2.total_energy.is_finite());
        assert!(h2.gap >= 0.0);
        // Homonuclear diatomic: both atoms must carry the same charge.
        let (q_first, q_second) = (h2.mulliken_charges[0], h2.mulliken_charges[1]);
        assert!((q_first - q_second).abs() < 0.01);
    }

    #[test]
    fn test_pm3_water() {
        let water = run(
            &[8, 1, 1],
            &[[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]],
        );
        // O contributes s + 3p, each H one s.
        assert_eq!(water.n_basis, 6);
        assert_eq!(water.n_electrons, 8);
        assert!(water.total_energy.is_finite());
        assert!(water.gap > 0.0, "Water should have a positive HOMO-LUMO gap");
        let (q_oxygen, q_hydrogen) = (water.mulliken_charges[0], water.mulliken_charges[1]);
        assert!(
            (q_oxygen - q_hydrogen).abs() > 0.001,
            "O and H charges should differ"
        );
    }

    #[test]
    fn test_pm3_methane() {
        let d = 0.629;
        let methane = run(
            &[6, 1, 1, 1, 1],
            &[
                [0.0, 0.0, 0.0],
                [d, d, d],
                [-d, -d, d],
                [d, -d, -d],
                [-d, d, -d],
            ],
        );
        assert_eq!(methane.n_basis, 8);
        assert_eq!(methane.n_electrons, 8);
        assert!(methane.total_energy.is_finite());
    }

    #[test]
    fn test_pm3_unsupported_element() {
        // Fe has no PM3 parameter set here, so the whole call must fail.
        let outcome = solve_pm3(&[26, 17], &[[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]);
        assert!(outcome.is_err());
    }
}