Skip to main content

sci_form/xtb/
solver.rs

1//! GFN0-xTB-inspired tight-binding solver.
2//!
3//! Implements a charge-self-consistent tight-binding scheme with
4//! repulsive pair potentials and Mulliken charge analysis.
5
6use super::params::{count_xtb_electrons, get_xtb_params, num_xtb_basis_functions};
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
/// Result of an xTB tight-binding calculation.
///
/// Energies are reported in eV. Per-atom vectors follow the order of the
/// input `elements`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XtbResult {
    /// Orbital energies (eV), sorted ascending; one entry per basis function.
    pub orbital_energies: Vec<f64>,
    /// Electronic energy (eV) evaluated in the last SCC iteration.
    pub electronic_energy: f64,
    /// Repulsive pair-potential energy (eV).
    pub repulsive_energy: f64,
    /// Total energy (eV) = electronic + repulsive.
    pub total_energy: f64,
    /// Number of basis functions.
    pub n_basis: usize,
    /// Number of valence electrons treated by the model.
    pub n_electrons: usize,
    /// HOMO energy (eV).
    pub homo_energy: f64,
    /// LUMO energy (eV); equals `homo_energy` when no unoccupied orbital exists.
    pub lumo_energy: f64,
    /// HOMO-LUMO gap (eV); 0.0 when no unoccupied orbital exists.
    pub gap: f64,
    /// Per-atom charges from the SCC cycle, defined as
    /// `n_valence - Mulliken population` (positive = electron-deficient).
    pub mulliken_charges: Vec<f64>,
    /// Number of SCC iterations performed.
    pub scc_iterations: usize,
    /// Whether SCC converged within the iteration limit.
    pub converged: bool,
}
38
/// Multiply a length in Å by this to get bohr (1 bohr = 0.529177 Å).
pub(crate) const ANGSTROM_TO_BOHR: f64 = 1.0 / 0.529_177;
/// Multiply an energy in Hartree by this to get eV.
pub(crate) const EV_PER_HARTREE: f64 = 27.211_4;
41
42/// Compute distance in bohr between two atoms.
43pub(crate) fn distance_bohr(a: &[f64; 3], b: &[f64; 3]) -> f64 {
44    let dx = (a[0] - b[0]) * ANGSTROM_TO_BOHR;
45    let dy = (a[1] - b[1]) * ANGSTROM_TO_BOHR;
46    let dz = (a[2] - b[2]) * ANGSTROM_TO_BOHR;
47    (dx * dx + dy * dy + dz * dz).sqrt()
48}
49
/// Approximate overlap of two s-type Slater orbitals separated by `r_bohr`.
///
/// Uses the equal-exponent 1s–1s form `exp(-ρ)·(1 + ρ + ρ²/3)` with
/// `ρ = ½(ζ_a + ζ_b)·R`. For coincident centres (`R < 1e-10` bohr) it returns
/// 1.0 when the exponents match (within 1e-10) and 0.0 otherwise.
pub(crate) fn sto_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    if r_bohr < 1e-10 {
        let same_exponent = (zeta_a - zeta_b).abs() < 1e-10;
        return if same_exponent { 1.0 } else { 0.0 };
    }
    let rho = 0.5 * (zeta_a + zeta_b) * r_bohr;
    let polynomial = 1.0 + rho + rho * rho / 3.0;
    (-rho).exp() * polynomial
}
62
63/// Build basis map: (atom_index, l_quantum, m_offset).
64pub(crate) fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
65    let mut basis = Vec::new();
66    for (i, &z) in elements.iter().enumerate() {
67        let n = num_xtb_basis_functions(z);
68        if n >= 1 {
69            basis.push((i, 0, 0));
70        } // s
71        if n >= 4 {
72            basis.push((i, 1, 0)); // px
73            basis.push((i, 1, 1)); // py
74            basis.push((i, 1, 2)); // pz
75        }
76        if n >= 9 {
77            for m in 0..5u8 {
78                basis.push((i, 2, m));
79            } // d orbitals
80        }
81    }
82    basis
83}
84
85/// Run an xTB tight-binding calculation.
86///
87/// `elements`: atomic numbers.
88/// `positions`: Cartesian coordinates in Å.
89/// SCF state for xTB gradient computation and GFN1 shell SCC.
90pub(crate) struct XtbScfState {
91    pub density: DMatrix<f64>,
92    pub coefficients: DMatrix<f64>,
93    pub orbital_energies: Vec<f64>,
94    pub basis_map: Vec<(usize, u8, u8)>,
95    pub n_occ: usize,
96    pub charges: Vec<f64>,
97    pub h_diag: Vec<f64>,
98    pub overlap: DMatrix<f64>,
99    pub hamiltonian: DMatrix<f64>,
100    pub s_half_inv: DMatrix<f64>,
101}
102
103/// Run xTB calculation returning both result and internal SCF state.
104pub(crate) fn solve_xtb_with_state(
105    elements: &[u8],
106    positions: &[[f64; 3]],
107) -> Result<(XtbResult, XtbScfState), String> {
108    if elements.len() != positions.len() {
109        return Err(format!(
110            "elements ({}) and positions ({}) length mismatch",
111            elements.len(),
112            positions.len()
113        ));
114    }
115
116    for &z in elements {
117        if get_xtb_params(z).is_none() {
118            return Err(format!("xTB parameters not available for Z={}", z));
119        }
120    }
121
122    let n_atoms = elements.len();
123    let basis_map = build_basis_map(elements);
124    let n_basis = basis_map.len();
125    let n_electrons = count_xtb_electrons(elements);
126    let n_occ = n_electrons / 2;
127
128    if n_basis == 0 {
129        return Err("No basis functions".to_string());
130    }
131
132    // Build overlap matrix
133    let mut s_mat = DMatrix::zeros(n_basis, n_basis);
134    for i in 0..n_basis {
135        s_mat[(i, i)] = 1.0;
136        let (atom_a, la, _) = basis_map[i];
137        for j in (i + 1)..n_basis {
138            let (atom_b, lb, _) = basis_map[j];
139            if atom_a == atom_b {
140                continue;
141            }
142            let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
143            let pa = get_xtb_params(elements[atom_a]).unwrap();
144            let pb = get_xtb_params(elements[atom_b]).unwrap();
145            let za = match la {
146                0 => pa.zeta_s,
147                1 => pa.zeta_p,
148                _ => pa.zeta_d,
149            };
150            let zb = match lb {
151                0 => pb.zeta_s,
152                1 => pb.zeta_p,
153                _ => pb.zeta_d,
154            };
155            if za < 1e-10 || zb < 1e-10 {
156                continue;
157            }
158            // Shell-dependent overlap scaling factors from Grimme's GFN0 parametrization.
159            // s-s: full overlap; s-p: reduced due to angular mismatch; p-p: further reduced.
160            // d-orbital scaling follows similar attenuation pattern.
161            let scale = match (la, lb) {
162                (0, 0) => 1.0,           // s-s
163                (0, 1) | (1, 0) => 0.65, // s-p (angular mismatch)
164                (1, 1) => 0.55,          // p-p σ approximation
165                (0, 2) | (2, 0) => 0.40, // s-d
166                (1, 2) | (2, 1) => 0.35, // p-d
167                (2, 2) => 0.30,          // d-d
168                _ => 0.5,
169            };
170            let sij = sto_overlap(za, zb, r) * scale;
171            s_mat[(i, j)] = sij;
172            s_mat[(j, i)] = sij;
173        }
174    }
175
176    // Build Hamiltonian: H_ii = level energy, H_ij = Wolfsberg-Helmholtz
177    let mut h_mat = DMatrix::zeros(n_basis, n_basis);
178    for i in 0..n_basis {
179        let (atom_a, la, _) = basis_map[i];
180        let pa = get_xtb_params(elements[atom_a]).unwrap();
181        h_mat[(i, i)] = match la {
182            0 => pa.h_s,
183            1 => pa.h_p,
184            _ => pa.h_d,
185        };
186    }
187    for i in 0..n_basis {
188        for j in (i + 1)..n_basis {
189            let (atom_a, _, _) = basis_map[i];
190            let (atom_b, _, _) = basis_map[j];
191            if atom_a == atom_b {
192                continue;
193            }
194            let k_wh = 1.75;
195            let hij = 0.5 * k_wh * s_mat[(i, j)] * (h_mat[(i, i)] + h_mat[(j, j)]);
196            h_mat[(i, j)] = hij;
197            h_mat[(j, i)] = hij;
198        }
199    }
200
201    // Repulsive energy: pair potential with coordination-number damping.
202    // First pass: compute coordination numbers for each atom.
203    let coord_numbers: Vec<f64> = (0..n_atoms)
204        .map(|a| {
205            let pa = get_xtb_params(elements[a]).unwrap();
206            let mut cn = 0.0;
207            for b in 0..n_atoms {
208                if b == a {
209                    continue;
210                }
211                let pb = get_xtb_params(elements[b]).unwrap();
212                let dx = positions[a][0] - positions[b][0];
213                let dy = positions[a][1] - positions[b][1];
214                let dz = positions[a][2] - positions[b][2];
215                let r = (dx * dx + dy * dy + dz * dz).sqrt();
216                let r_ref = pa.r_cov + pb.r_cov;
217                // Fermi-type counting function
218                cn += 1.0 / (1.0 + (-16.0 * (r_ref / r - 1.0)).exp());
219            }
220            cn
221        })
222        .collect();
223
224    let mut e_rep = 0.0;
225    for a in 0..n_atoms {
226        let pa = get_xtb_params(elements[a]).unwrap();
227        for b in (a + 1)..n_atoms {
228            let pb = get_xtb_params(elements[b]).unwrap();
229            let r_ang = {
230                let dx = positions[a][0] - positions[b][0];
231                let dy = positions[a][1] - positions[b][1];
232                let dz = positions[a][2] - positions[b][2];
233                (dx * dx + dy * dy + dz * dz).sqrt()
234            };
235            if r_ang < 0.1 {
236                continue;
237            }
238            let r_ref = pa.r_cov + pb.r_cov;
239            // Short-range repulsive with coordination-number dependent scaling.
240            // Effective Z is reduced for highly-coordinated atoms.
241            let alpha = 6.0;
242            let cn_a = coord_numbers[a];
243            let cn_b = coord_numbers[b];
244            let z_eff_a = (pa.n_valence as f64) / (1.0 + 0.1 * cn_a);
245            let z_eff_b = (pb.n_valence as f64) / (1.0 + 0.1 * cn_b);
246            e_rep += z_eff_a * z_eff_b * EV_PER_HARTREE / (r_ang * ANGSTROM_TO_BOHR)
247                * (-alpha * (r_ang / r_ref - 1.0)).exp();
248        }
249    }
250
251    // SCC (self-consistent charges) loop
252    let max_iter = 50;
253    let convergence = 1e-6;
254    let mut charges = vec![0.0f64; n_atoms];
255    let mut orbital_energies = vec![0.0; n_basis];
256    let mut coefficients = DMatrix::zeros(n_basis, n_basis);
257    let mut converged = false;
258    let mut scc_iter = 0;
259    let mut prev_e_elec = 0.0;
260
261    // Löwdin S^{-1/2}
262    let s_eigen = s_mat.clone().symmetric_eigen();
263    let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
264    for k in 0..n_basis {
265        let val = s_eigen.eigenvalues[k];
266        if val > 1e-8 {
267            let inv_sqrt = 1.0 / val.sqrt();
268            let col = s_eigen.eigenvectors.column(k);
269            for i in 0..n_basis {
270                for j in 0..n_basis {
271                    s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
272                }
273            }
274        }
275    }
276
277    // Pre-compute atom-pair gamma matrix (GPU-accelerated when available)
278    let gamma_atoms = {
279        let mut gm = vec![vec![0.0f64; n_atoms]; n_atoms];
280
281        #[cfg(feature = "experimental-gpu")]
282        let gpu_ok = {
283            let eta_vec: Vec<f64> = (0..n_atoms)
284                .map(|a| get_xtb_params(elements[a]).unwrap().eta)
285                .collect();
286            let pos_bohr: Vec<[f64; 3]> = positions
287                .iter()
288                .map(|p| {
289                    [
290                        p[0] * 1.8897259886,
291                        p[1] * 1.8897259886,
292                        p[2] * 1.8897259886,
293                    ]
294                })
295                .collect();
296            if n_atoms >= 8 {
297                if let Ok(ctx) = crate::gpu::context::GpuContext::try_create() {
298                    if let Ok(gpu_gamma) =
299                        super::gpu::build_xtb_gamma_gpu(&ctx, &eta_vec, &pos_bohr)
300                    {
301                        for a in 0..n_atoms {
302                            for b in 0..n_atoms {
303                                gm[a][b] = gpu_gamma[(a, b)];
304                            }
305                        }
306                        true
307                    } else {
308                        false
309                    }
310                } else {
311                    false
312                }
313            } else {
314                false
315            }
316        };
317
318        #[cfg(not(feature = "experimental-gpu"))]
319        let gpu_ok = false;
320
321        if !gpu_ok {
322            for a in 0..n_atoms {
323                let pa = get_xtb_params(elements[a]).unwrap();
324                gm[a][a] = pa.eta; // self-interaction
325                for b in (a + 1)..n_atoms {
326                    let pb = get_xtb_params(elements[b]).unwrap();
327                    let r_bohr = distance_bohr(&positions[a], &positions[b]);
328                    let gamma =
329                        1.0 / ((1.0 / pa.eta + 1.0 / pb.eta).powi(2) + r_bohr.powi(2)).sqrt();
330                    gm[a][b] = gamma;
331                    gm[b][a] = gamma;
332                }
333            }
334        }
335        gm
336    };
337
338    for iter in 0..max_iter {
339        scc_iter = iter + 1;
340
341        // Build charge-shifted Hamiltonian using pre-computed gamma matrix
342        let mut h_scc = h_mat.clone();
343
344        #[cfg(feature = "parallel")]
345        {
346            use rayon::prelude::*;
347            let shifts: Vec<f64> = (0..n_basis)
348                .into_par_iter()
349                .map(|i| {
350                    let (atom_a, _, _) = basis_map[i];
351                    let mut shift = 0.0;
352                    for b in 0..n_atoms {
353                        if b == atom_a {
354                            continue;
355                        }
356                        shift += gamma_atoms[atom_a][b] * charges[b];
357                    }
358                    shift += gamma_atoms[atom_a][atom_a] * charges[atom_a];
359                    shift
360                })
361                .collect();
362            for (i, s) in shifts.into_iter().enumerate() {
363                h_scc[(i, i)] += s;
364            }
365        }
366
367        #[cfg(not(feature = "parallel"))]
368        {
369            for i in 0..n_basis {
370                let (atom_a, _, _) = basis_map[i];
371                let mut shift = 0.0;
372                for b in 0..n_atoms {
373                    if b == atom_a {
374                        continue;
375                    }
376                    shift += gamma_atoms[atom_a][b] * charges[b];
377                }
378                shift += gamma_atoms[atom_a][atom_a] * charges[atom_a];
379                h_scc[(i, i)] += shift;
380            }
381        }
382
383        // Solve HC = SCε via Löwdin
384        let f_prime = &s_half_inv * &h_scc * &s_half_inv;
385        let eigen = f_prime.symmetric_eigen();
386
387        let mut indices: Vec<usize> = (0..n_basis).collect();
388        indices.sort_by(|&a, &b| {
389            eigen.eigenvalues[a]
390                .partial_cmp(&eigen.eigenvalues[b])
391                .unwrap_or(std::cmp::Ordering::Equal)
392        });
393
394        for (new_idx, &old_idx) in indices.iter().enumerate() {
395            orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
396        }
397
398        let c_prime = &eigen.eigenvectors;
399        let c_full = &s_half_inv * c_prime;
400        for new_idx in 0..n_basis {
401            let old_idx = indices[new_idx];
402            for i in 0..n_basis {
403                coefficients[(i, new_idx)] = c_full[(i, old_idx)];
404            }
405        }
406
407        // Build density matrix
408        let mut density = DMatrix::zeros(n_basis, n_basis);
409        for i in 0..n_basis {
410            for j in 0..n_basis {
411                let mut val = 0.0;
412                for k in 0..n_occ.min(n_basis) {
413                    val += coefficients[(i, k)] * coefficients[(j, k)];
414                }
415                density[(i, j)] = 2.0 * val;
416            }
417        }
418
419        // Mulliken charges
420        let ps = &density * &s_mat;
421        let mut new_charges = Vec::with_capacity(n_atoms);
422        for a in 0..n_atoms {
423            let pa = get_xtb_params(elements[a]).unwrap();
424            let mut pop = 0.0;
425            for i in 0..n_basis {
426                if basis_map[i].0 == a {
427                    pop += ps[(i, i)];
428                }
429            }
430            new_charges.push(pa.n_valence as f64 - pop);
431        }
432
433        // Electronic energy
434        let mut e_elec = 0.0;
435        for i in 0..n_basis {
436            for j in 0..n_basis {
437                e_elec += 0.5 * density[(i, j)] * (h_mat[(i, j)] + h_scc[(i, j)]);
438            }
439        }
440
441        // Convergence: check both energy AND charge changes.
442        let max_dq: f64 = charges
443            .iter()
444            .zip(new_charges.iter())
445            .map(|(old, new)| (old - new).abs())
446            .fold(0.0, f64::max);
447        let energy_converged = (e_elec - prev_e_elec).abs() < convergence && iter > 0;
448        let charge_converged = max_dq < convergence * 100.0; // 1e-4 for charges
449        if energy_converged && charge_converged {
450            converged = true;
451            prev_e_elec = e_elec;
452            charges = new_charges;
453            break;
454        }
455        prev_e_elec = e_elec;
456
457        // Adaptive SCC charge damping: start conservative, reduce as charges stabilize.
458        let damp = if max_dq > 0.5 {
459            0.6 // Large charge oscillations: heavy damping
460        } else if max_dq > 0.1 {
461            0.4 // Moderate oscillations
462        } else {
463            0.2 // Near convergence: mostly new charges
464        };
465        for a in 0..n_atoms {
466            charges[a] = damp * charges[a] + (1.0 - damp) * new_charges[a];
467        }
468    }
469
470    // Final electronic energy from last density
471    let e_elec = prev_e_elec;
472    let total_energy = e_elec + e_rep;
473
474    let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
475    let lumo_idx = n_occ.min(n_basis - 1);
476    let homo_energy = orbital_energies[homo_idx];
477    let lumo_energy = if n_occ < n_basis {
478        orbital_energies[lumo_idx]
479    } else {
480        homo_energy
481    };
482    let gap = if n_occ < n_basis {
483        lumo_energy - homo_energy
484    } else {
485        0.0
486    };
487
488    // Save diagonal Hamiltonian for gradient
489    let h_diag: Vec<f64> = (0..n_basis).map(|i| h_mat[(i, i)]).collect();
490
491    let state = XtbScfState {
492        density: {
493            // Rebuild final density from coefficients
494            let mut d = DMatrix::zeros(n_basis, n_basis);
495            for i in 0..n_basis {
496                for j in 0..n_basis {
497                    let mut val = 0.0;
498                    for k in 0..n_occ.min(n_basis) {
499                        val += coefficients[(i, k)] * coefficients[(j, k)];
500                    }
501                    d[(i, j)] = 2.0 * val;
502                }
503            }
504            d
505        },
506        coefficients: coefficients.clone(),
507        orbital_energies: orbital_energies.clone(),
508        basis_map,
509        n_occ,
510        charges: charges.clone(),
511        h_diag,
512        overlap: s_mat,
513        hamiltonian: h_mat,
514        s_half_inv,
515    };
516
517    Ok((
518        XtbResult {
519            orbital_energies,
520            electronic_energy: e_elec,
521            repulsive_energy: e_rep,
522            total_energy,
523            n_basis,
524            n_electrons,
525            homo_energy,
526            lumo_energy,
527            gap,
528            mulliken_charges: charges,
529            scc_iterations: scc_iter,
530            converged,
531        },
532        state,
533    ))
534}
535
536/// Run an xTB tight-binding calculation.
537pub fn solve_xtb(elements: &[u8], positions: &[[f64; 3]]) -> Result<XtbResult, String> {
538    solve_xtb_with_state(elements, positions).map(|(r, _)| r)
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xtb_h2() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 2);
        assert_eq!(result.n_electrons, 2);
        assert!(result.total_energy.is_finite());
        assert!(result.gap >= 0.0);
    }

    #[test]
    fn test_xtb_water() {
        let elements = [8u8, 1, 1];
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 6);
        assert_eq!(result.n_electrons, 8);
        assert!(result.total_energy.is_finite());
        assert!(result.gap > 0.0, "Water should have a positive gap");
    }

    #[test]
    fn test_xtb_ferrocene_atom() {
        // Just Fe atom — should work with s+p+d
        let elements = [26u8];
        let positions = [[0.0, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 9); // s+p+d
        assert_eq!(result.n_electrons, 8);
    }

    #[test]
    fn test_xtb_unsupported() {
        let elements = [92u8]; // uranium
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_xtb(&elements, &positions).is_err());
    }

    #[test]
    fn test_xtb_length_mismatch() {
        // Two elements but only one position must be rejected up front.
        let err = solve_xtb(&[1u8, 1], &[[0.0, 0.0, 0.0]]).unwrap_err();
        assert!(err.contains("length mismatch"), "unexpected error: {err}");
    }

    #[test]
    fn test_xtb_empty_input() {
        // No atoms means no basis functions, which is reported as an error.
        assert!(solve_xtb(&[], &[]).is_err());
    }

    #[test]
    fn test_sto_overlap_limits() {
        assert_eq!(sto_overlap(1.2, 1.2, 0.0), 1.0);
        let near = sto_overlap(1.2, 1.2, 1.0);
        let far = sto_overlap(1.2, 1.2, 4.0);
        assert!(near < 1.0);
        assert!(near > far && far > 0.0);
    }
}
579        let elements = [92u8]; // uranium
580        let positions = [[0.0, 0.0, 0.0]];
581        assert!(solve_xtb(&elements, &positions).is_err());
582    }
583}