1use super::params::{count_xtb_electrons, get_xtb_params, num_xtb_basis_functions};
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
/// Summary of a self-consistent-charge (SCC) xTB-style tight-binding calculation,
/// as returned by `solve_xtb`.
///
/// NOTE(review): energies appear to be in eV — the repulsion and Coulomb terms are
/// scaled by `EV_PER_HARTREE` — assuming the `h_*` / `eta` parameters are in eV; confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XtbResult {
    /// Orbital energies of the last SCC iteration, sorted in ascending order.
    pub orbital_energies: Vec<f64>,
    /// Electronic energy `½ Σ_ij P_ij (H0_ij + H_ij)` of the last evaluated SCC iteration.
    pub electronic_energy: f64,
    /// Pairwise short-range atom–atom repulsion energy.
    pub repulsive_energy: f64,
    /// `electronic_energy + repulsive_energy`.
    pub total_energy: f64,
    /// Number of basis functions in the molecule.
    pub n_basis: usize,
    /// Electron count reported by `count_xtb_electrons`.
    pub n_electrons: usize,
    /// Energy of the highest doubly occupied orbital (orbital 0 when no pair is occupied).
    pub homo_energy: f64,
    /// Energy of the first orbital above the occupied ones, or `homo_energy` when
    /// every orbital is occupied.
    pub lumo_energy: f64,
    /// `lumo_energy - homo_energy`, or `0.0` when there is no unoccupied orbital.
    pub gap: f64,
    /// Per-atom charges `n_valence - Mulliken population` (positive = electron-deficient).
    /// When SCC did not converge these are the mixer's next input charges.
    pub mulliken_charges: Vec<f64>,
    /// Number of SCC iterations performed.
    pub scc_iterations: usize,
    /// `true` when the electronic energy changed by less than 1e-6 between two iterations.
    pub converged: bool,
}
38
/// Bohr per Ångström: multiply a length in Å by this to obtain atomic units
/// (equivalently, 1 bohr ≈ 0.529177249 Å).
pub(crate) const ANGSTROM_TO_BOHR: f64 = 1.8897259886;

/// Electronvolts per hartree: multiply an energy in hartree by this to obtain eV.
pub(crate) const EV_PER_HARTREE: f64 = 27.21138505;
41
42pub(crate) fn distance_bohr(a: &[f64; 3], b: &[f64; 3]) -> f64 {
44 let dx = (a[0] - b[0]) * ANGSTROM_TO_BOHR;
45 let dy = (a[1] - b[1]) * ANGSTROM_TO_BOHR;
46 let dz = (a[2] - b[2]) * ANGSTROM_TO_BOHR;
47 (dx * dx + dy * dy + dz * dz).sqrt()
48}
49
/// Approximate overlap of two s-type Slater orbitals whose centres are `r_bohr` apart.
///
/// Evaluates `S(p) = e^{-p} · (1 + p + p²/3)` at `p = ½(ζ_a + ζ_b) · R`, i.e. the
/// equal-exponent 1s–1s closed form with the mean of the two exponents.
/// For (near-)coincident centres (`R < 1e-10` bohr) it returns `1.0` when the two
/// exponents agree to within 1e-10 and `0.0` otherwise.
pub(crate) fn sto_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    let coincident = r_bohr < 1e-10;
    if coincident {
        let same_exponent = (zeta_a - zeta_b).abs() < 1e-10;
        return if same_exponent { 1.0 } else { 0.0 };
    }

    let p = 0.5 * (zeta_a + zeta_b) * r_bohr;
    let polynomial = 1.0 + p + p * p / 3.0;
    (-p).exp() * polynomial
}
62
63pub(crate) fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
65 let mut basis = Vec::new();
66 for (i, &z) in elements.iter().enumerate() {
67 let n = num_xtb_basis_functions(z);
68 if n >= 1 {
69 basis.push((i, 0, 0));
70 } if n >= 4 {
72 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
76 if n >= 9 {
77 for m in 0..5u8 {
78 basis.push((i, 2, m));
79 } }
81 }
82 basis
83}
84
/// Internal SCF state returned by `solve_xtb_with_state` together with [`XtbResult`].
pub(crate) struct XtbScfState {
    /// Closed-shell density matrix `P_ij = 2 Σ_{k < n_occ} C_ik C_jk` built from `coefficients`.
    pub density: DMatrix<f64>,
    /// MO coefficients in the original (non-orthogonal) basis, `C = S^{-1/2} C'`,
    /// with columns ordered by ascending orbital energy.
    pub coefficients: DMatrix<f64>,
    /// Orbital energies matching the columns of `coefficients`.
    pub orbital_energies: Vec<f64>,
    /// `(atom index, l, m)` for every basis function.
    pub basis_map: Vec<(usize, u8, u8)>,
    /// Number of doubly occupied orbitals (`n_electrons / 2`, rounded down).
    pub n_occ: usize,
    /// Final per-atom charges (the same values reported as `XtbResult::mulliken_charges`).
    pub charges: Vec<f64>,
    /// Diagonal of `hamiltonian` (on-site energies without the SCC shift).
    pub h_diag: Vec<f64>,
    /// Overlap matrix S.
    pub overlap: DMatrix<f64>,
    /// Charge-independent core Hamiltonian H0 (no SCC shift applied).
    pub hamiltonian: DMatrix<f64>,
    /// Löwdin `S^{-1/2}`; eigenvalues of S at or below 1e-8 are discarded.
    pub s_half_inv: DMatrix<f64>,
}
102
103pub(crate) fn solve_xtb_with_state(
105 elements: &[u8],
106 positions: &[[f64; 3]],
107) -> Result<(XtbResult, XtbScfState), String> {
108 if elements.len() != positions.len() {
109 return Err(format!(
110 "elements ({}) and positions ({}) length mismatch",
111 elements.len(),
112 positions.len()
113 ));
114 }
115
116 for &z in elements {
117 if get_xtb_params(z).is_none() {
118 return Err(format!("xTB parameters not available for Z={}", z));
119 }
120 }
121
122 let n_atoms = elements.len();
123 let basis_map = build_basis_map(elements);
124 let n_basis = basis_map.len();
125 let n_electrons = count_xtb_electrons(elements);
126 let n_occ = n_electrons / 2;
127
128 if n_basis == 0 {
129 return Err("No basis functions".to_string());
130 }
131
132 let mut s_mat = DMatrix::zeros(n_basis, n_basis);
134 for i in 0..n_basis {
135 s_mat[(i, i)] = 1.0;
136 let (atom_a, la, _) = basis_map[i];
137 for j in (i + 1)..n_basis {
138 let (atom_b, lb, _) = basis_map[j];
139 if atom_a == atom_b {
140 continue;
141 }
142 let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
143 let pa = get_xtb_params(elements[atom_a]).unwrap();
144 let pb = get_xtb_params(elements[atom_b]).unwrap();
145 let za = match la {
146 0 => pa.zeta_s,
147 1 => pa.zeta_p,
148 _ => pa.zeta_d,
149 };
150 let zb = match lb {
151 0 => pb.zeta_s,
152 1 => pb.zeta_p,
153 _ => pb.zeta_d,
154 };
155 if za < 1e-10 || zb < 1e-10 {
156 continue;
157 }
158 let scale = match (la, lb) {
162 (0, 0) => 1.0, (0, 1) | (1, 0) => 0.65, (1, 1) => 0.55, (0, 2) | (2, 0) => 0.40, (1, 2) | (2, 1) => 0.35, (2, 2) => 0.30, _ => 0.5,
169 };
170 let sij = sto_overlap(za, zb, r) * scale;
171 s_mat[(i, j)] = sij;
172 s_mat[(j, i)] = sij;
173 }
174 }
175
176 let mut h_mat = DMatrix::zeros(n_basis, n_basis);
178 for i in 0..n_basis {
179 let (atom_a, la, _) = basis_map[i];
180 let pa = get_xtb_params(elements[atom_a]).unwrap();
181 h_mat[(i, i)] = match la {
182 0 => pa.h_s,
183 1 => pa.h_p,
184 _ => pa.h_d,
185 };
186 }
187 for i in 0..n_basis {
188 for j in (i + 1)..n_basis {
189 let (atom_a, _, _) = basis_map[i];
190 let (atom_b, _, _) = basis_map[j];
191 if atom_a == atom_b {
192 continue;
193 }
194 let k_wh = 1.75;
195 let hij = 0.5 * k_wh * s_mat[(i, j)] * (h_mat[(i, i)] + h_mat[(j, j)]);
196 h_mat[(i, j)] = hij;
197 h_mat[(j, i)] = hij;
198 }
199 }
200
201 let coord_numbers: Vec<f64> = (0..n_atoms)
204 .map(|a| {
205 let pa = get_xtb_params(elements[a]).unwrap();
206 let mut cn = 0.0;
207 for b in 0..n_atoms {
208 if b == a {
209 continue;
210 }
211 let pb = get_xtb_params(elements[b]).unwrap();
212 let dx = positions[a][0] - positions[b][0];
213 let dy = positions[a][1] - positions[b][1];
214 let dz = positions[a][2] - positions[b][2];
215 let r = (dx * dx + dy * dy + dz * dz).sqrt();
216 let r_ref = pa.r_cov + pb.r_cov;
217 cn += 1.0 / (1.0 + (-16.0 * (r_ref / r - 1.0)).exp());
219 }
220 cn
221 })
222 .collect();
223
224 let mut e_rep = 0.0;
225 for a in 0..n_atoms {
226 let pa = get_xtb_params(elements[a]).unwrap();
227 for b in (a + 1)..n_atoms {
228 let pb = get_xtb_params(elements[b]).unwrap();
229 let r_ang = {
230 let dx = positions[a][0] - positions[b][0];
231 let dy = positions[a][1] - positions[b][1];
232 let dz = positions[a][2] - positions[b][2];
233 (dx * dx + dy * dy + dz * dz).sqrt()
234 };
235 if r_ang < 0.1 {
236 continue;
237 }
238 let r_ref = pa.r_cov + pb.r_cov;
239 let alpha = 6.0;
242 let cn_a = coord_numbers[a];
243 let cn_b = coord_numbers[b];
244 let z_eff_a = (pa.n_valence as f64) / (1.0 + 0.1 * cn_a);
245 let z_eff_b = (pb.n_valence as f64) / (1.0 + 0.1 * cn_b);
246 e_rep += z_eff_a * z_eff_b * EV_PER_HARTREE / (r_ang * ANGSTROM_TO_BOHR)
247 * (-alpha * (r_ang / r_ref - 1.0)).exp();
248 }
249 }
250
251 let max_iter = 250;
253 let convergence = 1e-6;
254 let mut charges = vec![0.0f64; n_atoms];
255 let mut orbital_energies = vec![0.0; n_basis];
256 let mut coefficients = DMatrix::zeros(n_basis, n_basis);
257 let mut converged = false;
258 let mut scc_iter = 0;
259 let mut prev_e_elec = 0.0;
260
261 let mut mixer = super::broyden::BroydenMixer::new(n_atoms, 15, 0.4);
263
264 let s_eigen = s_mat.clone().symmetric_eigen();
266 let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
267 for k in 0..n_basis {
268 let val = s_eigen.eigenvalues[k];
269 if val > 1e-8 {
270 let inv_sqrt = 1.0 / val.sqrt();
271 let col = s_eigen.eigenvectors.column(k);
272 for i in 0..n_basis {
273 for j in 0..n_basis {
274 s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
275 }
276 }
277 }
278 }
279
280 let gamma_atoms = {
282 let mut gm = vec![vec![0.0f64; n_atoms]; n_atoms];
283
284 #[cfg(feature = "experimental-gpu")]
285 let gpu_ok = {
286 let eta_vec: Vec<f64> = (0..n_atoms)
287 .map(|a| get_xtb_params(elements[a]).unwrap().eta)
288 .collect();
289 let pos_bohr: Vec<[f64; 3]> = positions
290 .iter()
291 .map(|p| {
292 [
293 p[0] * ANGSTROM_TO_BOHR,
294 p[1] * ANGSTROM_TO_BOHR,
295 p[2] * ANGSTROM_TO_BOHR,
296 ]
297 })
298 .collect();
299 if n_atoms >= 8 {
300 if let Ok(ctx) = crate::gpu::context::GpuContext::try_create() {
301 if let Ok(gpu_gamma) =
302 super::gpu::build_xtb_gamma_gpu(&ctx, &eta_vec, &pos_bohr)
303 {
304 for a in 0..n_atoms {
305 for b in 0..n_atoms {
306 gm[a][b] = gpu_gamma[(a, b)];
307 }
308 }
309 true
310 } else {
311 false
312 }
313 } else {
314 false
315 }
316 } else {
317 false
318 }
319 };
320
321 #[cfg(not(feature = "experimental-gpu"))]
322 let gpu_ok = false;
323
324 if !gpu_ok {
325 for a in 0..n_atoms {
326 let pa = get_xtb_params(elements[a]).unwrap();
327 gm[a][a] = pa.eta; for b in (a + 1)..n_atoms {
329 let pb = get_xtb_params(elements[b]).unwrap();
330 let r_bohr = distance_bohr(&positions[a], &positions[b]);
331 let eta_a_ha = pa.eta / EV_PER_HARTREE;
334 let eta_b_ha = pb.eta / EV_PER_HARTREE;
335 let eta_avg_ha = 0.5 * (eta_a_ha + eta_b_ha);
336 let gamma_ha = 1.0 / (r_bohr.powi(2) + eta_avg_ha.powi(-2)).sqrt();
337 let gamma = gamma_ha * EV_PER_HARTREE;
338 gm[a][b] = gamma;
339 gm[b][a] = gamma;
340 }
341 }
342 }
343 gm
344 };
345
346 for iter in 0..max_iter {
347 scc_iter = iter + 1;
348
349 mixer.set(&charges);
351
352 let mut h_scc = h_mat.clone();
355 for i in 0..n_basis {
356 let atom_a = basis_map[i].0;
357 let mut shift = 0.0;
358 for b in 0..n_atoms {
359 shift += gamma_atoms[atom_a][b] * charges[b];
360 }
361 h_scc[(i, i)] -= shift;
362 }
363
364 let f_prime = &s_half_inv * &h_scc * &s_half_inv;
366 let eigen = f_prime.symmetric_eigen();
367
368 let mut indices: Vec<usize> = (0..n_basis).collect();
369 indices.sort_by(|&a, &b| {
370 eigen.eigenvalues[a]
371 .partial_cmp(&eigen.eigenvalues[b])
372 .unwrap_or(std::cmp::Ordering::Equal)
373 });
374
375 for (new_idx, &old_idx) in indices.iter().enumerate() {
376 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
377 }
378
379 let c_prime = &eigen.eigenvectors;
380 let c_full = &s_half_inv * c_prime;
381 for new_idx in 0..n_basis {
382 let old_idx = indices[new_idx];
383 for i in 0..n_basis {
384 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
385 }
386 }
387
388 let mut density = DMatrix::zeros(n_basis, n_basis);
390 for i in 0..n_basis {
391 for j in 0..n_basis {
392 let mut val = 0.0;
393 for k in 0..n_occ.min(n_basis) {
394 val += coefficients[(i, k)] * coefficients[(j, k)];
395 }
396 density[(i, j)] = 2.0 * val;
397 }
398 }
399
400 let ps = &density * &s_mat;
402 let mut new_charges = Vec::with_capacity(n_atoms);
403 for a in 0..n_atoms {
404 let pa = get_xtb_params(elements[a]).unwrap();
405 let mut pop = 0.0;
406 for i in 0..n_basis {
407 if basis_map[i].0 == a {
408 pop += ps[(i, i)];
409 }
410 }
411 new_charges.push(pa.n_valence as f64 - pop);
412 }
413
414 let mut e_elec = 0.0;
416 for i in 0..n_basis {
417 for j in 0..n_basis {
418 e_elec += 0.5 * density[(i, j)] * (h_mat[(i, j)] + h_scc[(i, j)]);
419 }
420 }
421
422 let de = (e_elec - prev_e_elec).abs();
424 if de < convergence && iter > 0 {
425 converged = true;
426 prev_e_elec = e_elec;
427 charges = new_charges;
428 break;
429 }
430 prev_e_elec = e_elec;
431
432 mixer.diff(&new_charges);
434 if iter > 0 {
435 let _ = mixer.step();
436 }
437 mixer.get(&mut charges);
438 }
439
440 let e_elec = prev_e_elec;
442 let total_energy = e_elec + e_rep;
443
444 let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
445 let lumo_idx = n_occ.min(n_basis - 1);
446 let homo_energy = orbital_energies[homo_idx];
447 let lumo_energy = if n_occ < n_basis {
448 orbital_energies[lumo_idx]
449 } else {
450 homo_energy
451 };
452 let gap = if n_occ < n_basis {
453 lumo_energy - homo_energy
454 } else {
455 0.0
456 };
457
458 let h_diag: Vec<f64> = (0..n_basis).map(|i| h_mat[(i, i)]).collect();
460
461 let state = XtbScfState {
462 density: {
463 let mut d = DMatrix::zeros(n_basis, n_basis);
465 for i in 0..n_basis {
466 for j in 0..n_basis {
467 let mut val = 0.0;
468 for k in 0..n_occ.min(n_basis) {
469 val += coefficients[(i, k)] * coefficients[(j, k)];
470 }
471 d[(i, j)] = 2.0 * val;
472 }
473 }
474 d
475 },
476 coefficients: coefficients.clone(),
477 orbital_energies: orbital_energies.clone(),
478 basis_map,
479 n_occ,
480 charges: charges.clone(),
481 h_diag,
482 overlap: s_mat,
483 hamiltonian: h_mat,
484 s_half_inv,
485 };
486
487 Ok((
488 XtbResult {
489 orbital_energies,
490 electronic_energy: e_elec,
491 repulsive_energy: e_rep,
492 total_energy,
493 n_basis,
494 n_electrons,
495 homo_energy,
496 lumo_energy,
497 gap,
498 mulliken_charges: charges,
499 scc_iterations: scc_iter,
500 converged,
501 },
502 state,
503 ))
504}
505
506pub fn solve_xtb(elements: &[u8], positions: &[[f64; 3]]) -> Result<XtbResult, String> {
508 solve_xtb_with_state(elements, positions).map(|(r, _)| r)
509}
510
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the solver on the given geometry, panicking if it reports an error.
    fn run(elements: &[u8], positions: &[[f64; 3]]) -> XtbResult {
        solve_xtb(elements, positions).unwrap()
    }

    #[test]
    fn test_xtb_h2() {
        let res = run(&[1, 1], &[[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]]);
        assert_eq!(res.n_basis, 2);
        assert_eq!(res.n_electrons, 2);
        assert!(res.total_energy.is_finite());
        assert!(res.gap >= 0.0);
    }

    #[test]
    fn test_xtb_water() {
        let res = run(
            &[8, 1, 1],
            &[[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]],
        );
        assert_eq!(res.n_basis, 6);
        assert_eq!(res.n_electrons, 8);
        assert!(res.total_energy.is_finite());
        assert!(res.gap > 0.0, "Water should have a positive gap");
    }

    #[test]
    fn test_xtb_ferrocene_atom() {
        // A single Fe atom: one s, three p and five d functions.
        let res = run(&[26], &[[0.0, 0.0, 0.0]]);
        assert_eq!(res.n_basis, 9);
        assert_eq!(res.n_electrons, 8);
    }

    #[test]
    fn test_xtb_unsupported() {
        // Uranium has no xTB parameters, so the solver must refuse it.
        let outcome = solve_xtb(&[92], &[[0.0, 0.0, 0.0]]);
        assert!(outcome.is_err());
    }
}