Skip to main content

sci_form/xtb/
gradients.rs

1//! Analytical gradients for GFN0-xTB tight-binding calculations.
2//!
3//! Computes dE/dR for each atom using the converged SCC density and
4//! Hellmann-Feynman + Pulay force expressions, plus the repulsive
5//! pair potential gradient.
6
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
10use super::params::get_xtb_params;
11use super::solver::{solve_xtb_with_state, sto_overlap, ANGSTROM_TO_BOHR, EV_PER_HARTREE};
12
/// Result of xTB gradient computation.
///
/// Produced by [`compute_xtb_gradient`]; derives `Serialize`/`Deserialize`
/// so it can be passed through serde-based interfaces.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XtbGradientResult {
    /// Energy gradient dE/dR per atom in eV/Å.
    ///
    /// One `[x, y, z]` triple per entry of the `elements` slice given to
    /// [`compute_xtb_gradient`], in the same order.
    pub gradients: Vec<[f64; 3]>,
    /// Total xTB energy in eV.
    ///
    /// Copied from the `total_energy` reported by the SCC solve.
    pub energy: f64,
}
21
22/// Compute analytical xTB energy gradients.
23///
24/// Runs a full xTB SCC, then computes the gradient from the converged
25/// density using Hellmann-Feynman + Pulay + repulsive pair potential derivatives.
26pub fn compute_xtb_gradient(
27    elements: &[u8],
28    positions: &[[f64; 3]],
29) -> Result<XtbGradientResult, String> {
30    let (result, state) = solve_xtb_with_state(elements, positions)?;
31
32    let n_atoms = elements.len();
33    let n_basis = state.basis_map.len();
34    let n_occ = state.n_occ;
35
36    // Energy-weighted density: W_ij = 2·Σ_{k∈occ} ε_k·C_ik·C_jk
37    let mut w_mat = DMatrix::zeros(n_basis, n_basis);
38    for i in 0..n_basis {
39        for j in 0..n_basis {
40            let mut val = 0.0;
41            for k in 0..n_occ.min(n_basis) {
42                val += state.orbital_energies[k]
43                    * state.coefficients[(i, k)]
44                    * state.coefficients[(j, k)];
45            }
46            w_mat[(i, j)] = 2.0 * val;
47        }
48    }
49
50    let mut gradients = vec![[0.0f64; 3]; n_atoms];
51    let h_step = 1e-6;
52    let k_wh = 1.75;
53    let rep_alpha = 6.0;
54
55    // Compute per-pair gradient contribution (returns grad_a; grad_b = -grad_a)
56    let compute_pair = |a: usize, b: usize| -> [f64; 3] {
57        let pa = get_xtb_params(elements[a]).unwrap();
58        let pb = get_xtb_params(elements[b]).unwrap();
59
60        let dx = positions[a][0] - positions[b][0];
61        let dy = positions[a][1] - positions[b][1];
62        let dz = positions[a][2] - positions[b][2];
63        let r_ang = (dx * dx + dy * dy + dz * dz).sqrt();
64        if r_ang < 0.01 {
65            return [0.0; 3];
66        }
67        let r_bohr = r_ang * ANGSTROM_TO_BOHR;
68        let dir = [dx / r_ang, dy / r_ang, dz / r_ang];
69        let mut grad_a = [0.0f64; 3];
70
71        // ── 1. Repulsive pair potential gradient ──
72        let r_ref = pa.r_cov + pb.r_cov;
73        let na = pa.n_valence as f64;
74        let nb = pb.n_valence as f64;
75        let exp_term = (-rep_alpha * (r_ang / r_ref - 1.0)).exp();
76        let de_rep_dr = na * nb * EV_PER_HARTREE * exp_term / (r_ang * ANGSTROM_TO_BOHR)
77            * (-1.0 / r_ang - rep_alpha / r_ref);
78        for d in 0..3 {
79            grad_a[d] += de_rep_dr * dir[d];
80        }
81
82        // ── 2. SCC charge-shift gradient ──
83        let eta_sum_sq = (1.0 / pa.eta + 1.0 / pb.eta).powi(2);
84        let gamma_denom = (eta_sum_sq + r_bohr * r_bohr).sqrt();
85        let dgamma_dr_bohr = -r_bohr / (gamma_denom * gamma_denom * gamma_denom);
86        let dgamma_dr_ang = dgamma_dr_bohr * ANGSTROM_TO_BOHR;
87        let mut pop_a = 0.0;
88        let mut pop_b = 0.0;
89        for mu in 0..n_basis {
90            if state.basis_map[mu].0 == a {
91                pop_a += state.density[(mu, mu)];
92            }
93            if state.basis_map[mu].0 == b {
94                pop_b += state.density[(mu, mu)];
95            }
96        }
97        let de_scc_dr = 0.5 * (pop_a * state.charges[b] + pop_b * state.charges[a]) * dgamma_dr_ang;
98        for d in 0..3 {
99            grad_a[d] += de_scc_dr * dir[d];
100        }
101
102        // ── 3. Hellmann-Feynman + Pulay ──
103        for mu in 0..n_basis {
104            if state.basis_map[mu].0 != a {
105                continue;
106            }
107            let la = state.basis_map[mu].1;
108            for nu in 0..n_basis {
109                if state.basis_map[nu].0 != b {
110                    continue;
111                }
112                let lb = state.basis_map[nu].1;
113                let za = match la {
114                    0 => pa.zeta_s,
115                    1 => pa.zeta_p,
116                    _ => pa.zeta_d,
117                };
118                let zb = match lb {
119                    0 => pb.zeta_s,
120                    1 => pb.zeta_p,
121                    _ => pb.zeta_d,
122                };
123                if za < 1e-10 || zb < 1e-10 {
124                    continue;
125                }
126                let scale = if la == 0 && lb == 0 {
127                    1.0
128                } else if la == lb {
129                    0.5
130                } else {
131                    0.6
132                };
133                let s_plus = sto_overlap(za, zb, r_bohr + h_step);
134                let s_minus = sto_overlap(za, zb, r_bohr - h_step);
135                let ds_dr_bohr = (s_plus - s_minus) / (2.0 * h_step) * scale;
136                let ds_dr_ang = ds_dr_bohr * ANGSTROM_TO_BOHR;
137                let h_ii = state.h_diag[mu];
138                let h_jj = state.h_diag[nu];
139                let dh_dr = 0.5 * k_wh * (h_ii + h_jj) * ds_dr_ang;
140                let p_mn = state.density[(mu, nu)];
141                let w_mn = w_mat[(mu, nu)];
142                let force = 2.0 * (p_mn * dh_dr - w_mn * ds_dr_ang);
143                for d in 0..3 {
144                    grad_a[d] += force * dir[d];
145                }
146            }
147        }
148
149        grad_a
150    };
151
152    let pairs: Vec<(usize, usize)> = (0..n_atoms)
153        .flat_map(|a| ((a + 1)..n_atoms).map(move |b| (a, b)))
154        .collect();
155
156    #[cfg(feature = "parallel")]
157    {
158        use rayon::prelude::*;
159        let pair_grads: Vec<(usize, usize, [f64; 3])> = pairs
160            .par_iter()
161            .map(|&(a, b)| (a, b, compute_pair(a, b)))
162            .collect();
163        for (a, b, g) in pair_grads {
164            for d in 0..3 {
165                gradients[a][d] += g[d];
166                gradients[b][d] -= g[d];
167            }
168        }
169    }
170
171    #[cfg(not(feature = "parallel"))]
172    {
173        for &(a, b) in &pairs {
174            let g = compute_pair(a, b);
175            for d in 0..3 {
176                gradients[a][d] += g[d];
177                gradients[b][d] -= g[d];
178            }
179        }
180    }
181
182    Ok(XtbGradientResult {
183        gradients,
184        energy: result.total_energy,
185    })
186}
187
#[cfg(test)]
mod tests {
    use super::super::solver::solve_xtb;
    use super::*;

    /// H2 at 0.74 Å: two gradient entries, finite energy, and the two atoms'
    /// gradient components cancel within 0.1.
    #[test]
    fn test_xtb_gradient_h2() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let result = compute_xtb_gradient(&elements, &positions).unwrap();
        assert_eq!(result.gradients.len(), 2);
        assert!(result.energy.is_finite());

        let (first, second) = (result.gradients[0], result.gradients[1]);
        for (ga, gb) in first.iter().zip(second.iter()) {
            assert!(
                (ga + gb).abs() < 0.1,
                "Forces not equal and opposite: {:?}",
                result.gradients
            );
        }
    }

    /// Water: every gradient component is finite and the per-axis sum over
    /// all atoms stays below 1.0 in magnitude.
    #[test]
    fn test_xtb_gradient_water() {
        let elements = [8u8, 1, 1];
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let result = compute_xtb_gradient(&elements, &positions).unwrap();
        assert_eq!(result.gradients.len(), 3);
        assert!(
            result.gradients.iter().flatten().all(|v| v.is_finite()),
            "Gradient must be finite"
        );

        // Net force ~ zero (translational invariance)
        for d in 0..3 {
            let sum: f64 = result.gradients.iter().map(|g| g[d]).sum();
            assert!(
                sum.abs() < 1.0,
                "Net force should be near zero, got {sum:.4}"
            );
        }
    }

    /// H2: the analytical gradient is compared with a central finite
    /// difference of the total energy (step 1e-5), allowing a deviation of
    /// 0.5 relative to max(|numerical|, 1).
    #[test]
    fn test_xtb_gradient_vs_numerical() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let analytical = compute_xtb_gradient(&elements, &positions).unwrap();

        let h = 1e-5;
        // Total energy with coordinate `d` of atom `a` shifted by `shift`.
        let energy_at = |a: usize, d: usize, shift: f64| {
            let mut displaced = positions.to_vec();
            displaced[a][d] += shift;
            solve_xtb(&elements, &displaced).unwrap().total_energy
        };

        for a in 0..2 {
            for d in 0..3 {
                let e_p = energy_at(a, d, h);
                let e_m = energy_at(a, d, -h);
                let numerical = (e_p - e_m) / (2.0 * h);
                let reference = analytical.gradients[a][d];
                let scale = numerical.abs().max(1.0);
                assert!(
                    (reference - numerical).abs() / scale < 0.5,
                    "Gradient mismatch atom {a} dir {d}: analytical={:.6} numerical={:.6}",
                    reference,
                    numerical,
                );
            }
        }
    }
}