1use super::params::{count_xtb_electrons, get_xtb_params, num_xtb_basis_functions};
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct XtbResult {
13 pub orbital_energies: Vec<f64>,
15 pub electronic_energy: f64,
17 pub repulsive_energy: f64,
19 pub total_energy: f64,
21 pub n_basis: usize,
23 pub n_electrons: usize,
25 pub homo_energy: f64,
27 pub lumo_energy: f64,
29 pub gap: f64,
31 pub mulliken_charges: Vec<f64>,
33 pub scc_iterations: usize,
35 pub converged: bool,
37}
38
/// Ångström → bohr conversion factor (`1 / 0.529177249`).
///
/// This is the same literal the GPU gamma path in `solve_xtb_with_state` uses
/// to convert positions, so CPU and GPU distances agree. It is also closer to
/// the CODATA bohr radius than the previous truncated `1.0 / 0.529177`.
pub(crate) const ANGSTROM_TO_BOHR: f64 = 1.8897259886;

/// Hartree → electronvolt conversion factor.
pub(crate) const EV_PER_HARTREE: f64 = 27.2114;
41
42pub(crate) fn distance_bohr(a: &[f64; 3], b: &[f64; 3]) -> f64 {
44 let dx = (a[0] - b[0]) * ANGSTROM_TO_BOHR;
45 let dy = (a[1] - b[1]) * ANGSTROM_TO_BOHR;
46 let dz = (a[2] - b[2]) * ANGSTROM_TO_BOHR;
47 (dx * dx + dy * dy + dz * dz).sqrt()
48}
49
/// Approximate overlap of two s-type Slater orbitals separated by `r_bohr`.
///
/// Uses the single-exponent 1s–1s expression `exp(-p) * (1 + p + p²/3)` with
/// `p = 0.5 * (zeta_a + zeta_b) * r_bohr`. For (near-)coincident centres
/// (`r_bohr < 1e-10`) it returns `1.0` when the exponents match to within
/// `1e-10`, and `0.0` otherwise.
pub(crate) fn sto_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    if r_bohr < 1e-10 {
        let same_exponent = (zeta_a - zeta_b).abs() < 1e-10;
        return if same_exponent { 1.0 } else { 0.0 };
    }

    let p = 0.5 * (zeta_a + zeta_b) * r_bohr;
    let polynomial = 1.0 + p + p * p / 3.0;
    polynomial * (-p).exp()
}
62
63pub(crate) fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
65 let mut basis = Vec::new();
66 for (i, &z) in elements.iter().enumerate() {
67 let n = num_xtb_basis_functions(z);
68 if n >= 1 {
69 basis.push((i, 0, 0));
70 } if n >= 4 {
72 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
76 if n >= 9 {
77 for m in 0..5u8 {
78 basis.push((i, 2, m));
79 } }
81 }
82 basis
83}
84
85pub(crate) struct XtbScfState {
91 pub density: DMatrix<f64>,
92 pub coefficients: DMatrix<f64>,
93 pub orbital_energies: Vec<f64>,
94 pub basis_map: Vec<(usize, u8, u8)>,
95 pub n_occ: usize,
96 pub charges: Vec<f64>,
97 pub h_diag: Vec<f64>,
98 pub overlap: DMatrix<f64>,
99 pub hamiltonian: DMatrix<f64>,
100 pub s_half_inv: DMatrix<f64>,
101}
102
/// Runs the simplified xTB self-consistent-charge (SCC) procedure for one
/// geometry and returns the public [`XtbResult`] together with the internal
/// [`XtbScfState`].
///
/// `positions[i]` holds the Cartesian coordinates of the atom whose atomic
/// number is `elements[i]`. Coordinates are converted with `ANGSTROM_TO_BOHR`,
/// so they are taken to be in Ångström.
///
/// Outline of the body:
/// 1. Validate the input and build a minimal basis (`build_basis_map`).
/// 2. Build an approximate overlap matrix `S` from `sto_overlap`, scaled per
///    shell pair, and a core Hamiltonian `H0`. `H0` has parameter shell
///    energies on the diagonal and `0.5 * 1.75 * S_ij * (H_ii + H_jj)` off the
///    diagonal.
/// 3. Compute coordination numbers and a pairwise repulsive energy.
/// 4. Build the atom-pair `gamma` matrix. The GPU kernel is used when the
///    `experimental-gpu` feature is enabled and there are at least 8 atoms;
///    otherwise, or on any GPU failure, the CPU formula is used.
/// 5. Run up to 50 SCC cycles. Each cycle:
///    - shifts every diagonal element of `H0` by `sum_b gamma[a][b] * q_b`
///      for its atom `a`;
///    - diagonalises in the Löwdin-orthogonalised basis;
///    - builds a closed-shell density from the lowest `n_electrons / 2`
///      orbitals;
///    - mixes in the new Mulliken charges with adaptive damping.
///
/// Convergence requires two conditions:
/// - the electronic-energy change is below `1e-6` (checked from the second
///   cycle on);
/// - the largest charge change is below `1e-4`.
///
/// If convergence is not reached, the last cycle's results are returned with
/// `converged == false`.
///
/// # Errors
/// Returns `Err` when:
/// - `elements` and `positions` have different lengths;
/// - `get_xtb_params` has no entry for an element;
/// - the basis is empty (e.g. no atoms).
pub(crate) fn solve_xtb_with_state(
    elements: &[u8],
    positions: &[[f64; 3]],
) -> Result<(XtbResult, XtbScfState), String> {
    if elements.len() != positions.len() {
        return Err(format!(
            "elements ({}) and positions ({}) length mismatch",
            elements.len(),
            positions.len()
        ));
    }

    // Reject unsupported elements up front. Every later
    // `get_xtb_params(..).unwrap()` in this function relies on this check.
    for &z in elements {
        if get_xtb_params(z).is_none() {
            return Err(format!("xTB parameters not available for Z={}", z));
        }
    }

    let n_atoms = elements.len();
    // One `(atom, l, m)` entry per basis function.
    let basis_map = build_basis_map(elements);
    let n_basis = basis_map.len();
    let n_electrons = count_xtb_electrons(elements);
    // Closed-shell occupation uses integer division. With an odd electron
    // count, the unpaired electron is not placed in any orbital.
    let n_occ = n_electrons / 2;

    if n_basis == 0 {
        return Err("No basis functions".to_string());
    }

    // Overlap matrix S:
    // - unit diagonal;
    // - zero between different functions on the same atom;
    // - between atoms, a scaled single-exponent STO overlap.
    let mut s_mat = DMatrix::zeros(n_basis, n_basis);
    for i in 0..n_basis {
        s_mat[(i, i)] = 1.0;
        let (atom_a, la, _) = basis_map[i];
        for j in (i + 1)..n_basis {
            let (atom_b, lb, _) = basis_map[j];
            if atom_a == atom_b {
                continue;
            }
            let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
            let pa = get_xtb_params(elements[atom_a]).unwrap();
            let pb = get_xtb_params(elements[atom_b]).unwrap();
            // Pick the Slater exponent of each function's shell (l = 0, 1, else d).
            let za = match la {
                0 => pa.zeta_s,
                1 => pa.zeta_p,
                _ => pa.zeta_d,
            };
            let zb = match lb {
                0 => pb.zeta_s,
                1 => pb.zeta_p,
                _ => pb.zeta_d,
            };
            // Shells with a (near-)zero exponent get no overlap.
            if za < 1e-10 || zb < 1e-10 {
                continue;
            }
            // Fixed per-shell-pair factors applied to the s-type overlap
            // formula. The origin of these values is not documented here.
            let scale = match (la, lb) {
                (0, 0) => 1.0,
                (0, 1) | (1, 0) => 0.65,
                (1, 1) => 0.55,
                (0, 2) | (2, 0) => 0.40,
                (1, 2) | (2, 1) => 0.35,
                (2, 2) => 0.30,
                _ => 0.5,
            };
            let sij = sto_overlap(za, zb, r) * scale;
            s_mat[(i, j)] = sij;
            s_mat[(j, i)] = sij;
        }
    }

    // Core Hamiltonian H0, diagonal: the shell energy parameter of each function.
    let mut h_mat = DMatrix::zeros(n_basis, n_basis);
    for i in 0..n_basis {
        let (atom_a, la, _) = basis_map[i];
        let pa = get_xtb_params(elements[atom_a]).unwrap();
        h_mat[(i, i)] = match la {
            0 => pa.h_s,
            1 => pa.h_p,
            _ => pa.h_d,
        };
    }
    // Off-diagonal, inter-atomic elements use the Wolfsberg–Helmholz form
    // H_ij = 0.5 * K * S_ij * (H_ii + H_jj) with K = 1.75. Same-atom
    // off-diagonal elements stay zero.
    for i in 0..n_basis {
        for j in (i + 1)..n_basis {
            let (atom_a, _, _) = basis_map[i];
            let (atom_b, _, _) = basis_map[j];
            if atom_a == atom_b {
                continue;
            }
            let k_wh = 1.75;
            let hij = 0.5 * k_wh * s_mat[(i, j)] * (h_mat[(i, i)] + h_mat[(j, j)]);
            h_mat[(i, j)] = hij;
            h_mat[(j, i)] = hij;
        }
    }

    // Coordination numbers use a Fermi-type counting function. Distances are
    // the raw input coordinates (not converted to bohr) compared with
    // r_cov_a + r_cov_b. A neighbour contributes ~1 when it is much closer
    // than that sum and ~0 when it is much farther.
    let coord_numbers: Vec<f64> = (0..n_atoms)
        .map(|a| {
            let pa = get_xtb_params(elements[a]).unwrap();
            let mut cn = 0.0;
            for b in 0..n_atoms {
                if b == a {
                    continue;
                }
                let pb = get_xtb_params(elements[b]).unwrap();
                let dx = positions[a][0] - positions[b][0];
                let dy = positions[a][1] - positions[b][1];
                let dz = positions[a][2] - positions[b][2];
                let r = (dx * dx + dy * dy + dz * dz).sqrt();
                let r_ref = pa.r_cov + pb.r_cov;
                cn += 1.0 / (1.0 + (-16.0 * (r_ref / r - 1.0)).exp());
            }
            cn
        })
        .collect();

    // Pairwise repulsion: z_eff_a * z_eff_b * EV_PER_HARTREE / r_bohr
    // * exp(-6 * (r / r_ref - 1)). Each effective charge is the valence
    // electron count divided by (1 + 0.1 * CN). Pairs closer than 0.1
    // (input units) are skipped.
    let mut e_rep = 0.0;
    for a in 0..n_atoms {
        let pa = get_xtb_params(elements[a]).unwrap();
        for b in (a + 1)..n_atoms {
            let pb = get_xtb_params(elements[b]).unwrap();
            let r_ang = {
                let dx = positions[a][0] - positions[b][0];
                let dy = positions[a][1] - positions[b][1];
                let dz = positions[a][2] - positions[b][2];
                (dx * dx + dy * dy + dz * dz).sqrt()
            };
            if r_ang < 0.1 {
                continue;
            }
            let r_ref = pa.r_cov + pb.r_cov;
            let alpha = 6.0;
            let cn_a = coord_numbers[a];
            let cn_b = coord_numbers[b];
            let z_eff_a = (pa.n_valence as f64) / (1.0 + 0.1 * cn_a);
            let z_eff_b = (pb.n_valence as f64) / (1.0 + 0.1 * cn_b);
            e_rep += z_eff_a * z_eff_b * EV_PER_HARTREE / (r_ang * ANGSTROM_TO_BOHR)
                * (-alpha * (r_ang / r_ref - 1.0)).exp();
        }
    }

    // SCC settings: at most 50 cycles and an energy tolerance of 1e-6. The
    // charge tolerance in the loop is 100x this value. The loop starts from
    // zero charges.
    let max_iter = 50;
    let convergence = 1e-6;
    let mut charges = vec![0.0f64; n_atoms];
    let mut orbital_energies = vec![0.0; n_basis];
    let mut coefficients = DMatrix::zeros(n_basis, n_basis);
    let mut converged = false;
    let mut scc_iter = 0;
    let mut prev_e_elec = 0.0;

    // Löwdin S^{-1/2} from the eigen-decomposition of S. Eigenvalues <= 1e-8
    // are dropped.
    // NOTE(review): each dropped direction becomes an exact zero-eigenvalue
    // eigenvector of F' = S^{-1/2} H S^{-1/2}. It would appear as a spurious
    // 0.0 orbital energy sorted among the real ones. Confirm this cannot
    // happen for the supported elements.
    let s_eigen = s_mat.clone().symmetric_eigen();
    let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
    for k in 0..n_basis {
        let val = s_eigen.eigenvalues[k];
        if val > 1e-8 {
            let inv_sqrt = 1.0 / val.sqrt();
            let col = s_eigen.eigenvectors.column(k);
            for i in 0..n_basis {
                for j in 0..n_basis {
                    s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
                }
            }
        }
    }

    // Atom-pair Coulomb kernel gamma[a][b], stored as n_atoms x n_atoms nested Vecs.
    let gamma_atoms = {
        let mut gm = vec![vec![0.0f64; n_atoms]; n_atoms];

        // Optional GPU path, tried only for >= 8 atoms. Failure to create a
        // context or to run the kernel leaves `gpu_ok == false`, and the CPU
        // formula below is used instead.
        // NOTE(review): positions here are converted with the literal
        // 1.8897259886, while `distance_bohr` uses `ANGSTROM_TO_BOHR`.
        // Confirm the two values agree.
        #[cfg(feature = "experimental-gpu")]
        let gpu_ok = {
            let eta_vec: Vec<f64> = (0..n_atoms)
                .map(|a| get_xtb_params(elements[a]).unwrap().eta)
                .collect();
            let pos_bohr: Vec<[f64; 3]> = positions
                .iter()
                .map(|p| {
                    [
                        p[0] * 1.8897259886,
                        p[1] * 1.8897259886,
                        p[2] * 1.8897259886,
                    ]
                })
                .collect();
            if n_atoms >= 8 {
                if let Ok(ctx) = crate::gpu::context::GpuContext::try_create() {
                    if let Ok(gpu_gamma) =
                        super::gpu::build_xtb_gamma_gpu(&ctx, &eta_vec, &pos_bohr)
                    {
                        for a in 0..n_atoms {
                            for b in 0..n_atoms {
                                gm[a][b] = gpu_gamma[(a, b)];
                            }
                        }
                        true
                    } else {
                        false
                    }
                } else {
                    false
                }
            } else {
                false
            }
        };

        #[cfg(not(feature = "experimental-gpu"))]
        let gpu_ok = false;

        // CPU formula:
        // - diagonal: gamma_aa = eta_a;
        // - off-diagonal: gamma_ab = 1 / sqrt((1/eta_a + 1/eta_b)^2 + r_ab^2),
        //   with r in bohr.
        // NOTE(review): for two atoms with equal eta at r -> 0, this tends to
        // eta/2 rather than the eta used on the diagonal. Confirm whether a
        // 0.5 factor (harmonic-mean eta) is intended, here and in the GPU
        // kernel.
        if !gpu_ok {
            for a in 0..n_atoms {
                let pa = get_xtb_params(elements[a]).unwrap();
                gm[a][a] = pa.eta;
                for b in (a + 1)..n_atoms {
                    let pb = get_xtb_params(elements[b]).unwrap();
                    let r_bohr = distance_bohr(&positions[a], &positions[b]);
                    let gamma =
                        1.0 / ((1.0 / pa.eta + 1.0 / pb.eta).powi(2) + r_bohr.powi(2)).sqrt();
                    gm[a][b] = gamma;
                    gm[b][a] = gamma;
                }
            }
        }
        gm
    };

    for iter in 0..max_iter {
        scc_iter = iter + 1;

        // Charge-dependent Hamiltonian. Only diagonal elements are shifted:
        // basis function `i` on atom `a` gets sum_b gamma[a][b] * q_b. The
        // b == a term is added after the loop, so the result is the full sum.
        // The parallel and sequential variants compute the same shifts.
        // NOTE(review): `charges` is n_valence - population, so an
        // electron-poor (positive) atom has its levels raised. Confirm this
        // sign convention is intended.
        let mut h_scc = h_mat.clone();

        #[cfg(feature = "parallel")]
        {
            use rayon::prelude::*;
            let shifts: Vec<f64> = (0..n_basis)
                .into_par_iter()
                .map(|i| {
                    let (atom_a, _, _) = basis_map[i];
                    let mut shift = 0.0;
                    for b in 0..n_atoms {
                        if b == atom_a {
                            continue;
                        }
                        shift += gamma_atoms[atom_a][b] * charges[b];
                    }
                    shift += gamma_atoms[atom_a][atom_a] * charges[atom_a];
                    shift
                })
                .collect();
            for (i, s) in shifts.into_iter().enumerate() {
                h_scc[(i, i)] += s;
            }
        }

        #[cfg(not(feature = "parallel"))]
        {
            for i in 0..n_basis {
                let (atom_a, _, _) = basis_map[i];
                let mut shift = 0.0;
                for b in 0..n_atoms {
                    if b == atom_a {
                        continue;
                    }
                    shift += gamma_atoms[atom_a][b] * charges[b];
                }
                shift += gamma_atoms[atom_a][atom_a] * charges[atom_a];
                h_scc[(i, i)] += shift;
            }
        }

        // Diagonalise in the orthogonalised basis: F' = S^{-1/2} H S^{-1/2}.
        let f_prime = &s_half_inv * &h_scc * &s_half_inv;
        let eigen = f_prime.symmetric_eigen();

        // Order eigenpairs by ascending eigenvalue. Incomparable values (NaN)
        // are treated as equal.
        let mut indices: Vec<usize> = (0..n_basis).collect();
        indices.sort_by(|&a, &b| {
            eigen.eigenvalues[a]
                .partial_cmp(&eigen.eigenvalues[b])
                .unwrap_or(std::cmp::Ordering::Equal)
        });

        for (new_idx, &old_idx) in indices.iter().enumerate() {
            orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
        }

        // Back-transform C = S^{-1/2} C' and store columns in sorted order.
        let c_prime = &eigen.eigenvectors;
        let c_full = &s_half_inv * c_prime;
        for new_idx in 0..n_basis {
            let old_idx = indices[new_idx];
            for i in 0..n_basis {
                coefficients[(i, new_idx)] = c_full[(i, old_idx)];
            }
        }

        // Closed-shell density: P_ij = 2 * sum over the n_occ lowest orbitals.
        let mut density = DMatrix::zeros(n_basis, n_basis);
        for i in 0..n_basis {
            for j in 0..n_basis {
                let mut val = 0.0;
                for k in 0..n_occ.min(n_basis) {
                    val += coefficients[(i, k)] * coefficients[(j, k)];
                }
                density[(i, j)] = 2.0 * val;
            }
        }

        // Mulliken charges: q_a = n_valence_a - sum over functions i on a of (P S)_ii.
        let ps = &density * &s_mat;
        let mut new_charges = Vec::with_capacity(n_atoms);
        for a in 0..n_atoms {
            let pa = get_xtb_params(elements[a]).unwrap();
            let mut pop = 0.0;
            for i in 0..n_basis {
                if basis_map[i].0 == a {
                    pop += ps[(i, i)];
                }
            }
            new_charges.push(pa.n_valence as f64 - pop);
        }

        // Electronic energy: 0.5 * sum_ij P_ij (H0_ij + H_scc_ij).
        let mut e_elec = 0.0;
        for i in 0..n_basis {
            for j in 0..n_basis {
                e_elec += 0.5 * density[(i, j)] * (h_mat[(i, j)] + h_scc[(i, j)]);
            }
        }

        // Converged when the energy change (not checked in the first cycle)
        // and the largest charge change are both small. On convergence, the
        // undamped new charges are kept.
        let max_dq: f64 = charges
            .iter()
            .zip(new_charges.iter())
            .map(|(old, new)| (old - new).abs())
            .fold(0.0, f64::max);
        let energy_converged = (e_elec - prev_e_elec).abs() < convergence && iter > 0;
        let charge_converged = max_dq < convergence * 100.0;
        if energy_converged && charge_converged {
            converged = true;
            prev_e_elec = e_elec;
            charges = new_charges;
            break;
        }
        prev_e_elec = e_elec;

        // Adaptive damping: the weight kept on the old charges is 0.6, 0.4 or
        // 0.2, with heavier damping for larger charge changes.
        let damp = if max_dq > 0.5 {
            0.6
        } else if max_dq > 0.1 {
            0.4
        } else {
            0.2
        };
        for a in 0..n_atoms {
            charges[a] = damp * charges[a] + (1.0 - damp) * new_charges[a];
        }
    }

    // `orbital_energies` and `coefficients` come from the Hamiltonian built
    // with the charges that entered the final cycle. `charges` has already
    // been updated past that point.
    let e_elec = prev_e_elec;
    let total_energy = e_elec + e_rep;

    // Frontier orbitals:
    // - with n_occ == 0, the "HOMO" is orbital 0;
    // - when every orbital is occupied, the LUMO falls back to the HOMO and
    //   the gap is 0.
    // NOTE(review): `homo_idx` is not clamped to n_basis - 1. This assumes
    // n_occ <= n_basis; confirm for all parameterised elements.
    let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
    let lumo_idx = n_occ.min(n_basis - 1);
    let homo_energy = orbital_energies[homo_idx];
    let lumo_energy = if n_occ < n_basis {
        orbital_energies[lumo_idx]
    } else {
        homo_energy
    };
    let gap = if n_occ < n_basis {
        lumo_energy - homo_energy
    } else {
        0.0
    };

    let h_diag: Vec<f64> = (0..n_basis).map(|i| h_mat[(i, i)]).collect();

    let state = XtbScfState {
        // Recomputed from the final `coefficients` with the same formula as the
        // in-loop density.
        density: {
            let mut d = DMatrix::zeros(n_basis, n_basis);
            for i in 0..n_basis {
                for j in 0..n_basis {
                    let mut val = 0.0;
                    for k in 0..n_occ.min(n_basis) {
                        val += coefficients[(i, k)] * coefficients[(j, k)];
                    }
                    d[(i, j)] = 2.0 * val;
                }
            }
            d
        },
        coefficients: coefficients.clone(),
        orbital_energies: orbital_energies.clone(),
        basis_map,
        n_occ,
        charges: charges.clone(),
        h_diag,
        overlap: s_mat,
        hamiltonian: h_mat,
        s_half_inv,
    };

    Ok((
        XtbResult {
            orbital_energies,
            electronic_energy: e_elec,
            repulsive_energy: e_rep,
            total_energy,
            n_basis,
            n_electrons,
            homo_energy,
            lumo_energy,
            gap,
            mulliken_charges: charges,
            scc_iterations: scc_iter,
            converged,
        },
        state,
    ))
}
535
536pub fn solve_xtb(elements: &[u8], positions: &[[f64; 3]]) -> Result<XtbResult, String> {
538 solve_xtb_with_state(elements, positions).map(|(r, _)| r)
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xtb_h2() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 2);
        assert_eq!(result.n_electrons, 2);
        assert!(result.total_energy.is_finite());
        assert!(result.gap >= 0.0);
    }

    #[test]
    fn test_xtb_water() {
        let elements = [8u8, 1, 1];
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 6);
        assert_eq!(result.n_electrons, 8);
        assert!(result.total_energy.is_finite());
        assert!(result.gap > 0.0, "Water should have a positive gap");
    }

    /// A single iron atom: 9 basis functions (s, p, d) and 8 valence electrons.
    #[test]
    fn test_xtb_fe_atom() {
        let elements = [26u8];
        let positions = [[0.0, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 9);
        assert_eq!(result.n_electrons, 8);
    }

    #[test]
    fn test_xtb_unsupported() {
        let elements = [92u8];
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_xtb(&elements, &positions).is_err());
    }

    /// Mismatched input lengths must be reported, not panic on indexing.
    #[test]
    fn test_xtb_length_mismatch() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0]];
        let err = solve_xtb(&elements, &positions).unwrap_err();
        assert!(err.contains("length mismatch"));
    }

    /// No atoms means no basis functions, which is an error.
    #[test]
    fn test_xtb_empty_input() {
        let elements: [u8; 0] = [];
        let positions: [[f64; 3]; 0] = [];
        assert!(solve_xtb(&elements, &positions).is_err());
    }

    /// Basis layout for water: O gets s + three p functions, each H one s.
    #[test]
    fn test_basis_map_water() {
        let expected: Vec<(usize, u8, u8)> =
            vec![(0, 0, 0), (0, 1, 0), (0, 1, 1), (0, 1, 2), (1, 0, 0), (2, 0, 0)];
        assert_eq!(build_basis_map(&[8u8, 1, 1]), expected);
    }

    /// Coincident-centre limits and decay with distance of the STO overlap.
    #[test]
    fn test_sto_overlap_limits() {
        assert_eq!(sto_overlap(1.3, 1.3, 0.0), 1.0);
        assert_eq!(sto_overlap(1.0, 2.0, 0.0), 0.0);
        let near = sto_overlap(1.0, 1.0, 1.0);
        let far = sto_overlap(1.0, 1.0, 4.0);
        assert!(near > far && far > 0.0);
    }
}