1use nalgebra::DMatrix;
10use serde::{Deserialize, Serialize};
11
/// Result of a `solve_gfn1` calculation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Gfn1Result {
    /// Orbital energies from the last SCC diagonalization, sorted ascending.
    pub orbital_energies: Vec<f64>,
    /// Electronic (SCC) energy of the final iteration; units follow the GFN0 Hamiltonian.
    pub electronic_energy: f64,
    /// Pairwise repulsion energy (scaled by 27.2114, presumably Hartree -> eV — confirm).
    pub repulsive_energy: f64,
    /// Dispersion correction energy (scaled by 27.2114, presumably Hartree -> eV — confirm).
    pub dispersion_energy: f64,
    /// Sum of the electronic, repulsive and dispersion energies.
    pub total_energy: f64,
    /// Number of basis functions.
    pub n_basis: usize,
    /// Number of electrons, as reported by the GFN0 step.
    pub n_electrons: usize,
    /// Highest occupied orbital energy (0.0 when no HOMO is available).
    pub homo_energy: f64,
    /// Lowest unoccupied orbital energy (0.0 when every orbital is occupied).
    pub lumo_energy: f64,
    /// `lumo_energy - homo_energy`.
    pub gap: f64,
    /// Per-atom Mulliken charges: valence electron count minus Mulliken population.
    pub mulliken_charges: Vec<f64>,
    /// Per-atom shell charges (reference minus population), indexed by angular momentum `l`;
    /// shells absent from the basis are 0.0.
    pub shell_charges: Vec<Vec<f64>>,
    /// Number of SCC iterations performed.
    pub scc_iterations: usize,
    /// Whether the SCC energy change dropped below the tolerance before the iteration limit.
    pub converged: bool,
}
44
/// Per-shell GFN1 parameter set.
///
/// NOTE(review): not referenced in the code visible here; the field meanings
/// below are inferred from their names — confirm against the parameter source.
#[derive(Debug, Clone)]
pub struct Gfn1ShellParams {
    /// Angular momentum of the shell (presumably 0 = s, 1 = p, 2 = d).
    pub l: u8,
    /// Presumably the shell's atomic level energy (diagonal H0 term) — confirm units.
    pub h_level: f64,
    /// Presumably the Slater exponent of the shell's basis functions.
    pub zeta: f64,
    /// Presumably the shell chemical hardness used in the Coulomb kernel.
    pub eta: f64,
    /// Presumably the reference occupation of the shell.
    pub occ: f64,
}
59
60pub fn solve_gfn1(elements: &[u8], positions: &[[f64; 3]]) -> Result<Gfn1Result, String> {
68 use crate::xtb::solver::solve_xtb_with_state;
69
70 for &z in elements {
72 if crate::xtb::params::get_xtb_params(z).is_none() {
73 return Err(format!("Element Z={} not supported by GFN1-xTB", z));
74 }
75 }
76
77 let n_atoms = elements.len();
78
79 let (gfn0, state) = solve_xtb_with_state(elements, positions)?;
81
82 let n_basis = state.basis_map.len();
83 let n_electrons = gfn0.n_electrons;
84 let n_occ = state.n_occ;
85
86 let mut shell_list: Vec<(usize, u8)> = Vec::new();
88 let mut basis_to_shell: Vec<usize> = Vec::with_capacity(n_basis);
89 for &(atom, l, _m) in &state.basis_map {
90 let shell_idx = shell_list
91 .iter()
92 .position(|&s| s == (atom, l))
93 .unwrap_or_else(|| {
94 shell_list.push((atom, l));
95 shell_list.len() - 1
96 });
97 basis_to_shell.push(shell_idx);
98 }
99 let n_shells = shell_list.len();
100
101 let ref_pop = compute_reference_populations(elements, &shell_list);
103
104 let shell_eta: Vec<f64> = shell_list
106 .iter()
107 .map(|&(atom, l)| {
108 let eta = crate::xtb::params::get_xtb_params(elements[atom])
109 .unwrap()
110 .eta;
111 match l {
112 0 => eta,
113 1 => eta * 0.85,
114 _ => eta * 0.70,
115 }
116 })
117 .collect();
118
119 let gamma = build_shell_gamma_matrix(positions, &shell_list, &shell_eta);
121
122 let mut shell_dq = mulliken_shell_charges(
124 &state.density,
125 &state.overlap,
126 &basis_to_shell,
127 &ref_pop,
128 n_shells,
129 n_basis,
130 );
131
132 let max_scc = 100;
134 let scc_tol = 1e-6;
135 let damp = 0.4;
136 let mut converged = false;
137 let mut scc_iter = 0;
138 let mut orbital_energies = state.orbital_energies.clone();
139 let mut coefficients = state.coefficients.clone();
140 let mut density = state.density.clone();
141 let mut prev_e_elec = 0.0;
142
143 for iter in 0..max_scc {
144 scc_iter = iter + 1;
145
146 let mut h_scc = state.hamiltonian.clone();
148 for mu in 0..n_basis {
149 let s_mu = basis_to_shell[mu];
150 let mut shift = 0.0;
151 for s in 0..n_shells {
152 shift += gamma[(s_mu, s)] * shell_dq[s];
153 }
154 h_scc[(mu, mu)] += shift;
155 }
156
157 let f_prime = &state.s_half_inv * &h_scc * &state.s_half_inv;
159 let eigen = f_prime.symmetric_eigen();
160
161 let mut indices: Vec<usize> = (0..n_basis).collect();
162 indices.sort_by(|&a, &b| {
163 eigen.eigenvalues[a]
164 .partial_cmp(&eigen.eigenvalues[b])
165 .unwrap_or(std::cmp::Ordering::Equal)
166 });
167
168 for (new_idx, &old_idx) in indices.iter().enumerate() {
169 orbital_energies[new_idx] = eigen.eigenvalues[old_idx];
170 }
171
172 let c_prime = &eigen.eigenvectors;
173 let c_full = &state.s_half_inv * c_prime;
174 for new_idx in 0..n_basis {
175 let old_idx = indices[new_idx];
176 for i in 0..n_basis {
177 coefficients[(i, new_idx)] = c_full[(i, old_idx)];
178 }
179 }
180
181 density = DMatrix::zeros(n_basis, n_basis);
183 for i in 0..n_basis {
184 for j in 0..n_basis {
185 let mut val = 0.0;
186 for k in 0..n_occ.min(n_basis) {
187 val += coefficients[(i, k)] * coefficients[(j, k)];
188 }
189 density[(i, j)] = 2.0 * val;
190 }
191 }
192
193 let new_dq = mulliken_shell_charges(
195 &density,
196 &state.overlap,
197 &basis_to_shell,
198 &ref_pop,
199 n_shells,
200 n_basis,
201 );
202
203 let mut e_elec = 0.0;
205 for i in 0..n_basis {
206 for j in 0..n_basis {
207 e_elec += 0.5 * density[(i, j)] * (state.hamiltonian[(i, j)] + h_scc[(i, j)]);
208 }
209 }
210
211 if (e_elec - prev_e_elec).abs() < scc_tol && iter > 0 {
212 converged = true;
213 prev_e_elec = e_elec;
214 shell_dq = new_dq;
215 break;
216 }
217 prev_e_elec = e_elec;
218
219 for s in 0..n_shells {
221 shell_dq[s] = damp * shell_dq[s] + (1.0 - damp) * new_dq[s];
222 }
223 }
224
225 let ps = &density * &state.overlap;
227 let mut mulliken_charges = Vec::with_capacity(n_atoms);
228 for a in 0..n_atoms {
229 let pa = crate::xtb::params::get_xtb_params(elements[a]).unwrap();
230 let mut pop = 0.0;
231 for mu in 0..n_basis {
232 if state.basis_map[mu].0 == a {
233 pop += ps[(mu, mu)];
234 }
235 }
236 mulliken_charges.push(pa.n_valence as f64 - pop);
237 }
238
239 let mut atom_shell_charges = vec![vec![0.0; 3]; n_atoms];
241 for (s, &(atom, l)) in shell_list.iter().enumerate() {
242 atom_shell_charges[atom][l as usize] = shell_dq[s];
243 }
244
245 let disp_energy = compute_d3bj_correction(elements, positions);
247
248 let rep_energy = compute_gfn1_repulsive(elements, positions);
250
251 let e_elec = prev_e_elec;
252 let total_energy = e_elec + rep_energy + disp_energy;
253
254 let homo_energy = if n_occ > 0 && n_occ <= orbital_energies.len() {
255 orbital_energies[n_occ - 1]
256 } else {
257 0.0
258 };
259 let lumo_energy = if n_occ < orbital_energies.len() {
260 orbital_energies[n_occ]
261 } else {
262 0.0
263 };
264
265 Ok(Gfn1Result {
266 orbital_energies,
267 electronic_energy: e_elec,
268 repulsive_energy: rep_energy,
269 dispersion_energy: disp_energy,
270 total_energy,
271 n_basis,
272 n_electrons,
273 homo_energy,
274 lumo_energy,
275 gap: lumo_energy - homo_energy,
276 mulliken_charges,
277 shell_charges: atom_shell_charges,
278 scc_iterations: scc_iter,
279 converged,
280 })
281}
282
283fn compute_reference_populations(elements: &[u8], shell_list: &[(usize, u8)]) -> Vec<f64> {
285 let mut ref_pop = vec![0.0; shell_list.len()];
286 for (idx, &(atom, l)) in shell_list.iter().enumerate() {
287 let params = crate::xtb::params::get_xtb_params(elements[atom]).unwrap();
288 let n_val = params.n_valence as f64;
289 let has_p = params.zeta_p > 0.0;
290 ref_pop[idx] = match l {
291 0 => n_val.clamp(0.0, 2.0),
292 1 => (n_val - 2.0).clamp(0.0, 6.0),
293 _ => {
294 let used = 2.0 + if has_p { 6.0 } else { 0.0 };
295 (n_val - used).clamp(0.0, 10.0)
296 }
297 };
298 }
299 ref_pop
300}
301
302fn mulliken_shell_charges(
304 density: &DMatrix<f64>,
305 overlap: &DMatrix<f64>,
306 basis_to_shell: &[usize],
307 ref_pop: &[f64],
308 n_shells: usize,
309 n_basis: usize,
310) -> Vec<f64> {
311 let ps = density * overlap;
312 let mut pop = vec![0.0; n_shells];
313 for mu in 0..n_basis {
314 pop[basis_to_shell[mu]] += ps[(mu, mu)];
315 }
316 let mut dq = vec![0.0; n_shells];
317 for s in 0..n_shells {
318 dq[s] = ref_pop[s] - pop[s];
319 }
320 dq
321}
322
323fn build_shell_gamma_matrix(
328 positions: &[[f64; 3]],
329 shell_list: &[(usize, u8)],
330 shell_eta: &[f64],
331) -> DMatrix<f64> {
332 let n = shell_list.len();
333 let mut gamma = DMatrix::zeros(n, n);
334
335 for i in 0..n {
336 let (atom_i, _) = shell_list[i];
337 gamma[(i, i)] = shell_eta[i];
338
339 for j in (i + 1)..n {
340 let (atom_j, _) = shell_list[j];
341
342 let gamma_ij = if atom_i == atom_j {
343 shell_eta[i] * shell_eta[j] / (shell_eta[i] + shell_eta[j])
345 } else {
346 let dx = positions[atom_i][0] - positions[atom_j][0];
348 let dy = positions[atom_i][1] - positions[atom_j][1];
349 let dz = positions[atom_i][2] - positions[atom_j][2];
350 let r_bohr = (dx * dx + dy * dy + dz * dz).sqrt() / 0.529177;
351 1.0 / ((1.0 / shell_eta[i] + 1.0 / shell_eta[j]).powi(2) + r_bohr.powi(2)).sqrt()
352 };
353
354 gamma[(i, j)] = gamma_ij;
355 gamma[(j, i)] = gamma_ij;
356 }
357 }
358
359 gamma
360}
361
362fn compute_d3bj_correction(elements: &[u8], positions: &[[f64; 3]]) -> f64 {
364 let n = elements.len();
365 let mut e_disp = 0.0;
366
367 let s6 = 1.0;
369 let s8 = 2.4;
370 let a1 = 0.63;
371 let a2 = 5.0;
372
373 for i in 0..n {
374 for j in (i + 1)..n {
375 let dx = positions[i][0] - positions[j][0];
376 let dy = positions[i][1] - positions[j][1];
377 let dz = positions[i][2] - positions[j][2];
378 let r = (dx * dx + dy * dy + dz * dz).sqrt();
379
380 if !(0.1..=50.0).contains(&r) {
381 continue;
382 }
383
384 let c6 = get_c6(elements[i], elements[j]);
385 let q_ij = get_r2r4(elements[i]) * get_r2r4(elements[j]);
390 let c8 = 3.0 * c6 * q_ij * q_ij;
391
392 let r0 = (c8 / c6).sqrt();
393 let f6 = 1.0 / (r.powi(6) + (a1 * r0 + a2).powi(6));
394 let f8 = 1.0 / (r.powi(8) + (a1 * r0 + a2).powi(8));
395
396 e_disp -= s6 * c6 * f6 + s8 * c8 * f8;
397 }
398 }
399
400 e_disp * 27.2114
402}
403
404fn compute_gfn1_repulsive(elements: &[u8], positions: &[[f64; 3]]) -> f64 {
406 let n = elements.len();
407 let mut e_rep = 0.0;
408
409 for i in 0..n {
410 let pi = crate::xtb::params::get_xtb_params(elements[i]).unwrap();
411 for j in (i + 1)..n {
412 let pj = crate::xtb::params::get_xtb_params(elements[j]).unwrap();
413 let dx = positions[i][0] - positions[j][0];
414 let dy = positions[i][1] - positions[j][1];
415 let dz = positions[i][2] - positions[j][2];
416 let r = (dx * dx + dy * dy + dz * dz).sqrt();
417
418 if r < 0.1 {
419 continue;
420 }
421
422 let r_ab = pi.r_cov + pj.r_cov;
423 let z_eff_i = pi.n_valence as f64;
424 let z_eff_j = pj.n_valence as f64;
425 let alpha = (z_eff_i * z_eff_j).sqrt();
426
427 e_rep += alpha * (-1.5 * r / r_ab).exp();
428 }
429 }
430
431 e_rep * 27.2114 }
433
434fn get_c6(z1: u8, z2: u8) -> f64 {
436 let c6_1 = atomic_c6(z1);
437 let c6_2 = atomic_c6(z2);
438 (2.0 * c6_1 * c6_2) / (c6_1 + c6_2 + 1e-30)
439}
440
/// Per-element C6 coefficient used by the pair combination rule.
///
/// Elements not listed fall back to 50.0.
/// NOTE(review): values are presumably in atomic units (Hartree·bohr^6) — confirm.
fn atomic_c6(z: u8) -> f64 {
    const C6_TABLE: [(u8, f64); 18] = [
        (1, 6.5),
        (6, 46.6),
        (7, 24.2),
        (8, 15.6),
        (9, 9.52),
        (14, 305.0),
        (15, 185.0),
        (16, 134.0),
        (17, 94.6),
        (22, 1044.0),
        (24, 602.0),
        (25, 552.0),
        (26, 482.0),
        (27, 408.0),
        (28, 373.0),
        (29, 253.0),
        (30, 284.0),
        (35, 162.0),
    ];
    C6_TABLE
        .iter()
        .find(|&&(element, _)| element == z)
        .map_or(50.0, |&(_, c6)| c6)
}
464
/// Per-element r2r4 scaling factor used to build C8 from C6.
///
/// Elements not listed fall back to 3.0.
fn get_r2r4(z: u8) -> f64 {
    const R2R4_TABLE: [(u8, f64); 10] = [
        (1, 2.00),
        (6, 3.09),
        (7, 2.71),
        (8, 2.44),
        (9, 1.91),
        (14, 4.17),
        (15, 3.63),
        (16, 3.49),
        (17, 3.01),
        (35, 3.47),
    ];
    R2R4_TABLE
        .iter()
        .find(|&&(element, _)| element == z)
        .map_or(3.0, |&(_, r2r4)| r2r4)
}
480
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gfn1_water() {
        let elements = vec![8u8, 1, 1];
        let positions = vec![
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        // `expect` surfaces the solver's error message instead of a bare
        // `is_ok()` failure.
        let r = solve_gfn1(&elements, &positions).expect("GFN1 water calculation failed");
        assert!(r.total_energy.is_finite());
        assert!(r.gap > 0.0);

        assert_eq!(r.mulliken_charges.len(), elements.len());
        assert_eq!(r.shell_charges.len(), elements.len());

        // Neutral closed-shell molecule: Mulliken charges must sum to zero.
        let total_charge: f64 = r.mulliken_charges.iter().sum();
        assert!(total_charge.abs() < 1e-6, "total charge {}", total_charge);

        // Oxygen is more electronegative than hydrogen.
        assert!(r.mulliken_charges[0] < 0.0);
    }
}