1use super::basis::{build_sto3g_basis, ANG_TO_BOHR};
7use super::cis::{compute_cis_with_dipole, CisResult};
8use super::d3::compute_d3_energy;
9use super::fock::nuclear_repulsion;
10use super::gcp::compute_gcp;
11use super::integrals::compute_eris;
12use super::nuclear::compute_nuclear_matrix;
13use super::overlap_kin::{compute_kinetic_matrix, compute_overlap_matrix};
14use super::scf::{solve_scf, ScfConfig};
15use super::srb::compute_srb;
16use serde::{Deserialize, Serialize};
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct HfConfig {
21 pub max_scf_iter: usize,
23 pub diis_size: usize,
25 pub n_cis_states: usize,
27 pub corrections: bool,
29}
30
31impl Default for HfConfig {
32 fn default() -> Self {
33 HfConfig {
34 max_scf_iter: 300,
35 diis_size: 6,
36 n_cis_states: 5,
37 corrections: true,
38 }
39 }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Hf3cResult {
48 pub energy: f64,
50 pub hf_energy: f64,
52 pub nuclear_repulsion: f64,
54 pub d3_energy: f64,
56 pub gcp_energy: f64,
58 pub srb_energy: f64,
60 pub orbital_energies: Vec<f64>,
62 pub scf_iterations: usize,
64 pub converged: bool,
66 pub cis: Option<CisResult>,
68 pub n_basis: usize,
70 pub n_electrons: usize,
72 pub homo_energy: f64,
74 pub lumo_energy: Option<f64>,
76 pub gap: f64,
78 pub mulliken_charges: Vec<f64>,
80}
81
/// Re-expresses the HF-3c STO-3G basis in the layout consumed by the
/// experimental GPU ERI code (`crate::scf::basis::BasisSet`).
///
/// Every source shell yields one `ContractedShell` plus its Cartesian basis
/// functions: one function for an S shell, three for a P shell (x, y, z in that
/// order). `function_to_atom` pairs each function with its atom index, and
/// `n_basis` is the total number of functions produced.
#[cfg(feature = "experimental-gpu")]
fn hf_basis_to_gpu_basis(basis: &super::basis::BasisSet) -> crate::scf::basis::BasisSet {
    use super::basis::ShellType;
    use crate::scf::basis::{
        BasisFunction as GpuBasisFunction, BasisSet as GpuBasisSet,
        ContractedShell as GpuContractedShell, GaussianPrimitive,
    };

    let mut gpu_shells = Vec::new();
    let mut gpu_functions = Vec::new();
    let mut owner_atoms = Vec::new();

    for src in &basis.shells {
        let prims: Vec<GaussianPrimitive> = src
            .exponents
            .iter()
            .copied()
            .zip(src.coefficients.iter().copied())
            .map(|(alpha, coefficient)| GaussianPrimitive { alpha, coefficient })
            .collect();

        // (shell angular momentum, per-function l_total, Cartesian exponent triples).
        let (l, l_total, components): (_, _, &[[_; 3]]) = match src.shell_type {
            ShellType::S => (0, 0, &[[0, 0, 0]]),
            ShellType::P => (1, 1, &[[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
        };

        for &angular in components {
            gpu_functions.push(GpuBasisFunction {
                atom_index: src.center_idx,
                center: src.center,
                angular,
                l_total,
                primitives: prims.clone(),
            });
            owner_atoms.push(src.center_idx);
        }

        // The shell takes ownership of the primitive list once all of its
        // functions have their own copies.
        gpu_shells.push(GpuContractedShell {
            atom_index: src.center_idx,
            center: src.center,
            l,
            primitives: prims,
        });
    }

    GpuBasisSet {
        n_basis: gpu_functions.len(),
        functions: gpu_functions,
        shells: gpu_shells,
        function_to_atom: owner_atoms,
    }
}
147
148pub fn solve_hf3c(
150 elements: &[u8],
151 positions: &[[f64; 3]],
152 config: &HfConfig,
153) -> Result<Hf3cResult, String> {
154 if elements.len() != positions.len() {
155 return Err("elements/positions length mismatch".to_string());
156 }
157 if elements.is_empty() {
158 return Err("empty molecule".to_string());
159 }
160
161 let pos_bohr: Vec<[f64; 3]> = positions
163 .iter()
164 .map(|p| [p[0] * ANG_TO_BOHR, p[1] * ANG_TO_BOHR, p[2] * ANG_TO_BOHR])
165 .collect();
166
167 let basis = build_sto3g_basis(elements, positions);
169 let n_basis = basis.n_basis();
170
171 let n_electrons: usize = elements.iter().map(|&z| z as usize).sum();
173
174 let s_mat = compute_overlap_matrix(&basis);
176 let t_mat = compute_kinetic_matrix(&basis);
177 let v_mat = compute_nuclear_matrix(&basis, elements, &pos_bohr);
178 let h_core = &t_mat + &v_mat;
179
180 let eris = compute_eris(&basis);
182
183 #[cfg(feature = "experimental-gpu")]
184 let gpu_eris_full = if n_basis >= 4 {
185 let n4 = (n_basis as u64)
188 .saturating_mul(n_basis as u64)
189 .saturating_mul(n_basis as u64)
190 .saturating_mul(n_basis as u64);
191 let mem_bytes = n4.saturating_mul(8);
192 let max_mem: u64 = 512 * 1024 * 1024; if mem_bytes > max_mem {
195 None } else if let Ok(ctx) = crate::gpu::context::GpuContext::try_create() {
197 let gpu_basis = hf_basis_to_gpu_basis(&basis);
198 crate::gpu::two_electron_gpu::compute_eris_gpu(&ctx, &gpu_basis)
199 .ok()
200 .map(|gpu_eris| {
201 let cap = n_basis * n_basis * n_basis * n_basis;
202 let mut full = Vec::with_capacity(cap);
203 for mu in 0..n_basis {
204 for nu in 0..n_basis {
205 for lam in 0..n_basis {
206 for sig in 0..n_basis {
207 full.push(gpu_eris.get(mu, nu, lam, sig));
208 }
209 }
210 }
211 }
212 full
213 })
214 } else {
215 None
216 }
217 } else {
218 None
219 };
220
221 #[cfg(not(feature = "experimental-gpu"))]
222 let gpu_eris_full: Option<Vec<f64>> = None;
223
224 let scf_config = ScfConfig {
226 max_iter: config.max_scf_iter,
227 diis_size: config.diis_size,
228 ..ScfConfig::default()
229 };
230 let scf_result = solve_scf(
231 &h_core,
232 &s_mat,
233 &eris,
234 gpu_eris_full.as_deref(),
235 n_electrons,
236 &scf_config,
237 );
238
239 let e_nuc = nuclear_repulsion(elements, &pos_bohr);
241
242 let (d3_e, gcp_e, srb_e) = if config.corrections {
244 (
245 compute_d3_energy(elements, &pos_bohr).energy,
246 compute_gcp(elements, &pos_bohr),
247 compute_srb(elements, &pos_bohr),
248 )
249 } else {
250 (0.0, 0.0, 0.0)
251 };
252
253 let total = scf_result.energy + e_nuc + d3_e + gcp_e + srb_e;
254
255 let cis = if config.n_cis_states > 0 && scf_result.converged {
257 let n_occ = n_electrons / 2;
258 let ao_map = super::basis::ao_to_atom_map(&basis);
259 Some(compute_cis_with_dipole(
260 &scf_result.orbital_energies,
261 &scf_result.coefficients,
262 &eris,
263 n_basis,
264 n_occ,
265 config.n_cis_states,
266 Some(&pos_bohr),
267 Some(&ao_map),
268 ))
269 } else {
270 None
271 };
272
273 let n_occ = n_electrons / 2;
275 let homo_energy = if n_occ > 0 && n_occ <= scf_result.orbital_energies.len() {
276 scf_result.orbital_energies[n_occ - 1]
277 } else {
278 0.0
279 };
280 let lumo_energy = if n_occ < scf_result.orbital_energies.len() {
281 Some(scf_result.orbital_energies[n_occ])
282 } else {
283 None
284 };
285 let gap = lumo_energy.map_or(0.0, |l| l - homo_energy);
286
287 let mulliken_charges = if scf_result.converged {
289 let ps = &scf_result.density * &s_mat;
290 let ao_to_atom = super::basis::ao_to_atom_map(&basis);
291 let mut charges = vec![0.0_f64; elements.len()];
292 for mu in 0..n_basis {
293 charges[ao_to_atom[mu]] += ps[(mu, mu)];
294 }
295 charges
296 .iter()
297 .enumerate()
298 .map(|(i, &pop)| elements[i] as f64 - pop)
299 .collect()
300 } else {
301 vec![0.0; elements.len()]
302 };
303
304 Ok(Hf3cResult {
305 energy: total,
306 hf_energy: scf_result.energy + e_nuc,
307 nuclear_repulsion: e_nuc,
308 d3_energy: d3_e,
309 gcp_energy: gcp_e,
310 srb_energy: srb_e,
311 orbital_energies: scf_result.orbital_energies,
312 scf_iterations: scf_result.iterations,
313 converged: scf_result.converged,
314 cis,
315 n_basis,
316 n_electrons,
317 homo_energy,
318 lumo_energy,
319 gap,
320 mulliken_charges,
321 })
322}
323
/// Runs [`solve_hf3c`] on every `(elements, positions)` pair, spreading the
/// molecules across the rayon thread pool.
///
/// The returned vector is in input order, with one `Result` per molecule. A
/// failure for one molecule does not affect the others.
#[cfg(feature = "parallel")]
pub fn solve_hf3c_batch(
    molecules: &[(&[u8], &[[f64; 3]])],
    config: &HfConfig,
) -> Vec<Result<Hf3cResult, String>> {
    use rayon::prelude::*;

    let results: Vec<Result<Hf3cResult, String>> = molecules
        .par_iter()
        .map(|&(atoms, coords)| solve_hf3c(atoms, coords, config))
        .collect();
    results
}
336
337#[cfg(not(feature = "parallel"))]
339pub fn solve_hf3c_batch(
340 molecules: &[(&[u8], &[[f64; 3]])],
341 config: &HfConfig,
342) -> Vec<Result<Hf3cResult, String>> {
343 molecules
344 .iter()
345 .map(|(els, pos)| solve_hf3c(els, pos, config))
346 .collect()
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// The default SCF iteration budget is part of the public contract.
    #[test]
    fn test_default_hf3c_iteration_budget() {
        assert_eq!(HfConfig::default().max_scf_iter, 300);
    }

    /// H2 at 0.74 Å: finite, bound total energy; CIS is skipped when no states
    /// are requested.
    #[test]
    fn test_h2_hf3c() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]];
        let config = HfConfig {
            n_cis_states: 0,
            ..Default::default()
        };
        let result = solve_hf3c(&elements, &positions, &config).unwrap();
        assert!(result.energy.is_finite(), "Energy should be finite");
        assert!(result.energy < 0.0, "H2 total energy should be negative");
        assert!(result.cis.is_none(), "CIS must be skipped for n_cis_states == 0");
    }

    /// Water with default settings: finite energy and one orbital per STO-3G
    /// basis function (5 on O + 1 per H = 7).
    #[test]
    fn test_water_hf3c() {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let result = solve_hf3c(&elements, &positions, &HfConfig::default()).unwrap();
        assert!(result.energy.is_finite());
        assert_eq!(result.orbital_energies.len(), 7);
    }

    /// Mismatched element/coordinate counts are rejected before any work is done.
    #[test]
    fn test_length_mismatch_is_error() {
        let err = solve_hf3c(&[1u8, 1], &[[0.0; 3]], &HfConfig::default()).unwrap_err();
        assert_eq!(err, "elements/positions length mismatch");
    }

    /// An empty molecule is rejected.
    #[test]
    fn test_empty_molecule_is_error() {
        let err = solve_hf3c(&[], &[], &HfConfig::default()).unwrap_err();
        assert_eq!(err, "empty molecule");
    }
}