1use super::params::{count_xtb_electrons, get_xtb_params, num_xtb_basis_functions};
7use nalgebra::DMatrix;
8use serde::{Deserialize, Serialize};
9
/// Outcome of a [`solve_xtb`] tight-binding calculation.
///
/// Energies are in the units of the per-element xTB parameters.
/// NOTE(review): presumably eV, given that the repulsion term is scaled by
/// `EV_PER_HARTREE` — confirm against `params`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct XtbResult {
    /// Orbital energies in ascending order, one per basis function.
    pub orbital_energies: Vec<f64>,
    /// Electronic energy of the SCC cycle, ½·Σ_ij P_ij·(H0_ij + H_scc_ij).
    pub electronic_energy: f64,
    /// Empirical pairwise repulsion energy between atoms.
    pub repulsive_energy: f64,
    /// Sum of `electronic_energy` and `repulsive_energy`.
    pub total_energy: f64,
    /// Number of basis functions (atomic orbitals).
    pub n_basis: usize,
    /// Electron count as returned by `count_xtb_electrons`.
    pub n_electrons: usize,
    /// Energy of the highest occupied orbital.
    pub homo_energy: f64,
    /// Energy of the lowest unoccupied orbital; equals `homo_energy` when no
    /// unoccupied orbital exists.
    pub lumo_energy: f64,
    /// `lumo_energy - homo_energy`, or 0 when no unoccupied orbital exists.
    pub gap: f64,
    /// Mulliken partial charge per atom (valence electrons minus gross population).
    pub mulliken_charges: Vec<f64>,
    /// Number of SCC iterations performed.
    pub scc_iterations: usize,
    /// Whether the SCC electronic energy converged within the iteration limit.
    pub converged: bool,
}
38
/// Ångström → bohr conversion factor (1 bohr = 0.529177 Å).
const ANGSTROM_TO_BOHR: f64 = 1.0 / 0.529_177;
/// Electronvolts per hartree.
const EV_PER_HARTREE: f64 = 27.211_4;

/// Distance between two Cartesian points given in Ångström, returned in bohr.
///
/// Every coordinate difference is converted to bohr before it is squared.
fn distance_bohr(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let squared: f64 = a
        .iter()
        .zip(b.iter())
        .map(|(&from, &to)| (from - to) * ANGSTROM_TO_BOHR)
        .map(|delta| delta * delta)
        .sum();
    squared.sqrt()
}
49
/// Approximate overlap of two s-type Slater orbitals separated by `r_bohr`.
///
/// Uses the equal-exponent 1s–1s expression `e^{-ρ}(1 + ρ + ρ²/3)` with
/// `ρ = ζ̄·R`, where `ζ̄` is the arithmetic mean of the two exponents.
/// For (near-)coincident centres (`r_bohr < 1e-10`) the result is 1 when the
/// exponents match within 1e-10 and 0 otherwise.
fn sto_overlap(zeta_a: f64, zeta_b: f64, r_bohr: f64) -> f64 {
    let coincident = r_bohr < 1e-10;
    if !coincident {
        let mean_zeta = 0.5 * (zeta_a + zeta_b);
        let rho = mean_zeta * r_bohr;
        let polynomial = 1.0 + rho + rho * rho / 3.0;
        return (-rho).exp() * polynomial;
    }
    // Same centre: matching exponents count as the same orbital, anything
    // else is returned as zero overlap.
    let same_exponent = (zeta_a - zeta_b).abs() < 1e-10;
    if same_exponent {
        1.0
    } else {
        0.0
    }
}
62
63fn build_basis_map(elements: &[u8]) -> Vec<(usize, u8, u8)> {
65 let mut basis = Vec::new();
66 for (i, &z) in elements.iter().enumerate() {
67 let n = num_xtb_basis_functions(z);
68 if n >= 1 {
69 basis.push((i, 0, 0));
70 } if n >= 4 {
72 basis.push((i, 1, 0)); basis.push((i, 1, 1)); basis.push((i, 1, 2)); }
76 if n >= 9 {
77 for m in 0..5u8 {
78 basis.push((i, 2, m));
79 } }
81 }
82 basis
83}
84
85pub fn solve_xtb(elements: &[u8], positions: &[[f64; 3]]) -> Result<XtbResult, String> {
90 if elements.len() != positions.len() {
91 return Err(format!(
92 "elements ({}) and positions ({}) length mismatch",
93 elements.len(),
94 positions.len()
95 ));
96 }
97
98 for &z in elements {
99 if get_xtb_params(z).is_none() {
100 return Err(format!("xTB parameters not available for Z={}", z));
101 }
102 }
103
104 let n_atoms = elements.len();
105 let basis_map = build_basis_map(elements);
106 let n_basis = basis_map.len();
107 let n_electrons = count_xtb_electrons(elements);
108 let n_occ = n_electrons / 2;
109
110 if n_basis == 0 {
111 return Err("No basis functions".to_string());
112 }
113
114 let mut s_mat = DMatrix::zeros(n_basis, n_basis);
116 for i in 0..n_basis {
117 s_mat[(i, i)] = 1.0;
118 let (atom_a, la, _) = basis_map[i];
119 for j in (i + 1)..n_basis {
120 let (atom_b, lb, _) = basis_map[j];
121 if atom_a == atom_b {
122 continue;
123 }
124 let r = distance_bohr(&positions[atom_a], &positions[atom_b]);
125 let pa = get_xtb_params(elements[atom_a]).unwrap();
126 let pb = get_xtb_params(elements[atom_b]).unwrap();
127 let za = match la {
128 0 => pa.zeta_s,
129 1 => pa.zeta_p,
130 _ => pa.zeta_d,
131 };
132 let zb = match lb {
133 0 => pb.zeta_s,
134 1 => pb.zeta_p,
135 _ => pb.zeta_d,
136 };
137 if za < 1e-10 || zb < 1e-10 {
138 continue;
139 }
140 let scale = if la == 0 && lb == 0 {
142 1.0
143 } else if la == lb {
144 0.5
145 } else {
146 0.6
147 };
148 let sij = sto_overlap(za, zb, r) * scale;
149 s_mat[(i, j)] = sij;
150 s_mat[(j, i)] = sij;
151 }
152 }
153
154 let mut h_mat = DMatrix::zeros(n_basis, n_basis);
156 for i in 0..n_basis {
157 let (atom_a, la, _) = basis_map[i];
158 let pa = get_xtb_params(elements[atom_a]).unwrap();
159 h_mat[(i, i)] = match la {
160 0 => pa.h_s,
161 1 => pa.h_p,
162 _ => pa.h_d,
163 };
164 }
165 for i in 0..n_basis {
166 for j in (i + 1)..n_basis {
167 let (atom_a, _, _) = basis_map[i];
168 let (atom_b, _, _) = basis_map[j];
169 if atom_a == atom_b {
170 continue;
171 }
172 let k_wh = 1.75;
173 let hij = 0.5 * k_wh * s_mat[(i, j)] * (h_mat[(i, i)] + h_mat[(j, j)]);
174 h_mat[(i, j)] = hij;
175 h_mat[(j, i)] = hij;
176 }
177 }
178
179 let mut e_rep = 0.0;
181 for a in 0..n_atoms {
182 let pa = get_xtb_params(elements[a]).unwrap();
183 for b in (a + 1)..n_atoms {
184 let pb = get_xtb_params(elements[b]).unwrap();
185 let r_ang = {
186 let dx = positions[a][0] - positions[b][0];
187 let dy = positions[a][1] - positions[b][1];
188 let dz = positions[a][2] - positions[b][2];
189 (dx * dx + dy * dy + dz * dz).sqrt()
190 };
191 if r_ang < 0.1 {
192 continue;
193 }
194 let r_ref = pa.r_cov + pb.r_cov;
195 let alpha = 6.0; e_rep += (pa.n_valence as f64) * (pb.n_valence as f64) * EV_PER_HARTREE
198 / (r_ang * ANGSTROM_TO_BOHR)
199 * (-alpha * (r_ang / r_ref - 1.0)).exp();
200 }
201 }
202
203 let max_iter = 50;
205 let convergence = 1e-6;
206 let mut charges = vec![0.0f64; n_atoms];
207 let mut orbital_energies = vec![0.0; n_basis];
208 let mut coefficients = DMatrix::zeros(n_basis, n_basis);
209 let mut converged = false;
210 let mut scc_iter = 0;
211 let mut prev_e_elec = 0.0;
212
213 let s_eigen = s_mat.clone().symmetric_eigen();
215 let mut s_half_inv = DMatrix::zeros(n_basis, n_basis);
216 for k in 0..n_basis {
217 let val = s_eigen.eigenvalues[k];
218 if val > 1e-8 {
219 let inv_sqrt = 1.0 / val.sqrt();
220 let col = s_eigen.eigenvectors.column(k);
221 for i in 0..n_basis {
222 for j in 0..n_basis {
223 s_half_inv[(i, j)] += inv_sqrt * col[i] * col[j];
224 }
225 }
226 }
227 }
228
229 for iter in 0..max_iter {
230 scc_iter = iter + 1;
231
232 let mut h_scc = h_mat.clone();
234 for i in 0..n_basis {
235 let (atom_a, _, _) = basis_map[i];
236 let pa = get_xtb_params(elements[atom_a]).unwrap();
237 let mut shift = 0.0;
239 for b in 0..n_atoms {
240 if b == atom_a {
241 continue;
242 }
243 let pb = get_xtb_params(elements[b]).unwrap();
244 let r_bohr = distance_bohr(&positions[atom_a], &positions[b]);
245 let gamma = 1.0 / ((1.0 / pa.eta + 1.0 / pb.eta).powi(2) + r_bohr.powi(2)).sqrt();
246 shift += gamma * charges[b];
247 }
248 shift += pa.eta * charges[atom_a];
250 h_scc[(i, i)] += shift;
251 }
252
253 let f_prime = &s_half_inv * &h_scc * &s_half_inv;
255 let eigen = f_prime.symmetric_eigen();
256
257 let mut indices: Vec<usize> = (0..n_basis).collect();
258 indices.sort_by(|&a, &b| {
259 eigen.eigenvalues[a]
260 .partial_cmp(&eigen.eigenvalues[b])
261 .unwrap_or(std::cmp::Ordering::Equal)
262 });
263
264 for (new_idx, &old_idx) in indices.iter().enumerate() {
265 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
266 }
267
268 let c_prime = &eigen.eigenvectors;
269 let c_full = &s_half_inv * c_prime;
270 for new_idx in 0..n_basis {
271 let old_idx = indices[new_idx];
272 for i in 0..n_basis {
273 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
274 }
275 }
276
277 let mut density = DMatrix::zeros(n_basis, n_basis);
279 for i in 0..n_basis {
280 for j in 0..n_basis {
281 let mut val = 0.0;
282 for k in 0..n_occ.min(n_basis) {
283 val += coefficients[(i, k)] * coefficients[(j, k)];
284 }
285 density[(i, j)] = 2.0 * val;
286 }
287 }
288
289 let ps = &density * &s_mat;
291 let mut new_charges = Vec::with_capacity(n_atoms);
292 for a in 0..n_atoms {
293 let pa = get_xtb_params(elements[a]).unwrap();
294 let mut pop = 0.0;
295 for i in 0..n_basis {
296 if basis_map[i].0 == a {
297 pop += ps[(i, i)];
298 }
299 }
300 new_charges.push(pa.n_valence as f64 - pop);
301 }
302
303 let mut e_elec = 0.0;
305 for i in 0..n_basis {
306 for j in 0..n_basis {
307 e_elec += 0.5 * density[(i, j)] * (h_mat[(i, j)] + h_scc[(i, j)]);
308 }
309 }
310
311 if (e_elec - prev_e_elec).abs() < convergence && iter > 0 {
312 converged = true;
313 charges = new_charges;
314 break;
315 }
316 prev_e_elec = e_elec;
317
318 let damp = 0.4;
320 for a in 0..n_atoms {
321 charges[a] = damp * charges[a] + (1.0 - damp) * new_charges[a];
322 }
323 }
324
325 let e_elec = prev_e_elec;
327 let total_energy = e_elec + e_rep;
328
329 let homo_idx = if n_occ > 0 { n_occ - 1 } else { 0 };
330 let lumo_idx = n_occ.min(n_basis - 1);
331 let homo_energy = orbital_energies[homo_idx];
332 let lumo_energy = if n_occ < n_basis {
333 orbital_energies[lumo_idx]
334 } else {
335 homo_energy
336 };
337 let gap = if n_occ < n_basis {
338 lumo_energy - homo_energy
339 } else {
340 0.0
341 };
342
343 Ok(XtbResult {
344 orbital_energies,
345 electronic_energy: e_elec,
346 repulsive_energy: e_rep,
347 total_energy,
348 n_basis,
349 n_electrons,
350 homo_energy,
351 lumo_energy,
352 gap,
353 mulliken_charges: charges,
354 scc_iterations: scc_iter,
355 converged,
356 })
357}
358
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_xtb_h2() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.74, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 2);
        assert_eq!(result.n_electrons, 2);
        assert_eq!(result.orbital_energies.len(), 2);
        assert_eq!(result.mulliken_charges.len(), 2);
        assert!(result.total_energy.is_finite());
        assert!(result.gap >= 0.0);
    }

    #[test]
    fn test_xtb_water() {
        let elements = [8u8, 1, 1];
        let positions = [[0.0, 0.0, 0.0], [0.757, 0.586, 0.0], [-0.757, 0.586, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 6);
        assert_eq!(result.n_electrons, 8);
        assert!(result.total_energy.is_finite());
        assert!(result.gap > 0.0, "Water should have a positive gap");
        assert!(
            result.orbital_energies.windows(2).all(|w| w[0] <= w[1]),
            "orbital energies should be sorted ascending"
        );
    }

    /// A single iron atom exercises the s + p + d basis (9 functions).
    #[test]
    fn test_xtb_iron_atom() {
        let elements = [26u8];
        let positions = [[0.0, 0.0, 0.0]];
        let result = solve_xtb(&elements, &positions).unwrap();
        assert_eq!(result.n_basis, 9);
        assert_eq!(result.n_electrons, 8);
    }

    #[test]
    fn test_xtb_unsupported() {
        let elements = [92u8];
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_xtb(&elements, &positions).is_err());
    }

    #[test]
    fn test_xtb_length_mismatch() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0]];
        assert!(solve_xtb(&elements, &positions).is_err());
    }
}