Skip to main content

sci_form/hf/
api.rs

//! Public API for the HF-3c composite method.
//!
//! Combines a Hartree–Fock SCF in the minimal STO-3G basis with the D3
//! dispersion, gCP (BSSE) and SRB (short-range) corrections, plus an optional
//! CIS excited-state calculation for UV-Vis spectroscopy.
5
6use super::basis::{build_sto3g_basis, ANG_TO_BOHR};
7use super::cis::{compute_cis_with_dipole, CisResult};
8use super::d3::compute_d3_energy;
9use super::fock::nuclear_repulsion;
10use super::gcp::compute_gcp;
11use super::integrals::compute_eris;
12use super::nuclear::compute_nuclear_matrix;
13use super::overlap_kin::{compute_kinetic_matrix, compute_overlap_matrix};
14use super::scf::{solve_scf, ScfConfig};
15use super::srb::compute_srb;
16use serde::{Deserialize, Serialize};
17
18/// Configuration for HF-3c calculation.
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct HfConfig {
21    /// Maximum SCF iterations.
22    pub max_scf_iter: usize,
23    /// DIIS subspace size.
24    pub diis_size: usize,
25    /// Number of CIS excited states to compute (0 = skip CIS).
26    pub n_cis_states: usize,
27    /// Include empirical corrections (D3, gCP, SRB).
28    pub corrections: bool,
29}
30
31impl Default for HfConfig {
32    fn default() -> Self {
33        HfConfig {
34            max_scf_iter: 300,
35            diis_size: 6,
36            n_cis_states: 5,
37            corrections: true,
38        }
39    }
40}
41
42/// Result of an HF-3c calculation.
43///
44/// Energy breakdown: `energy = hf_energy + d3_energy + gcp_energy + srb_energy`
45/// where `hf_energy = electronic + nuclear_repulsion`.
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct Hf3cResult {
48    /// Total HF-3c energy (Hartree).
49    pub energy: f64,
50    /// Pure HF electronic energy.
51    pub hf_energy: f64,
52    /// Nuclear repulsion energy.
53    pub nuclear_repulsion: f64,
54    /// D3 dispersion correction energy.
55    pub d3_energy: f64,
56    /// gCP BSSE correction energy.
57    pub gcp_energy: f64,
58    /// SRB short-range correction energy.
59    pub srb_energy: f64,
60    /// Orbital energies (sorted, eV).
61    pub orbital_energies: Vec<f64>,
62    /// Number of SCF iterations.
63    pub scf_iterations: usize,
64    /// Whether SCF converged.
65    pub converged: bool,
66    /// CIS excitation results (if requested).
67    pub cis: Option<CisResult>,
68    /// Number of basis functions.
69    pub n_basis: usize,
70    /// Number of electrons.
71    pub n_electrons: usize,
72    /// HOMO energy (eV).
73    pub homo_energy: f64,
74    /// LUMO energy (eV), if available.
75    pub lumo_energy: Option<f64>,
76    /// HOMO–LUMO gap (eV).
77    pub gap: f64,
78    /// Mulliken charges per atom.
79    pub mulliken_charges: Vec<f64>,
80}
81
/// Convert the HF module's STO-3G basis into the layout used by the GPU ERI code.
///
/// Each contracted shell becomes one GPU `ContractedShell` (`l` = 0 for s,
/// 1 for p) and is expanded into basis functions in shell order: a single s
/// function, or three p functions with angular exponents x, y, z in that
/// order. `function_to_atom` records the owning atom of every emitted function.
#[cfg(feature = "experimental-gpu")]
fn hf_basis_to_gpu_basis(basis: &super::basis::BasisSet) -> crate::scf::basis::BasisSet {
    use super::basis::ShellType;
    use crate::scf::basis::{
        BasisFunction as GpuBasisFunction, BasisSet as GpuBasisSet,
        ContractedShell as GpuContractedShell, GaussianPrimitive,
    };

    let mut shells = Vec::with_capacity(basis.shells.len());
    let mut functions = Vec::new();
    let mut function_to_atom = Vec::new();

    for shell in basis.shells.iter() {
        let atom = shell.center_idx;
        let center = shell.center;

        // Pair each exponent with its contraction coefficient.
        let contraction: Vec<GaussianPrimitive> = shell
            .exponents
            .iter()
            .copied()
            .zip(shell.coefficients.iter().copied())
            .map(|(alpha, coefficient)| GaussianPrimitive { alpha, coefficient })
            .collect();

        // Emit the shell's basis functions and yield its angular momentum.
        let l = match shell.shell_type {
            ShellType::S => {
                functions.push(GpuBasisFunction {
                    atom_index: atom,
                    center,
                    angular: [0, 0, 0],
                    l_total: 0,
                    primitives: contraction.clone(),
                });
                function_to_atom.push(atom);
                0
            }
            ShellType::P => {
                for angular in [[1, 0, 0], [0, 1, 0], [0, 0, 1]] {
                    functions.push(GpuBasisFunction {
                        atom_index: atom,
                        center,
                        angular,
                        l_total: 1,
                        primitives: contraction.clone(),
                    });
                    function_to_atom.push(atom);
                }
                1
            }
        };

        shells.push(GpuContractedShell {
            atom_index: atom,
            center,
            l,
            primitives: contraction,
        });
    }

    GpuBasisSet {
        n_basis: functions.len(),
        functions,
        shells,
        function_to_atom,
    }
}
147
148/// Run a complete HF-3c calculation.
149pub fn solve_hf3c(
150    elements: &[u8],
151    positions: &[[f64; 3]],
152    config: &HfConfig,
153) -> Result<Hf3cResult, String> {
154    if elements.len() != positions.len() {
155        return Err("elements/positions length mismatch".to_string());
156    }
157    if elements.is_empty() {
158        return Err("empty molecule".to_string());
159    }
160
161    // Convert positions to Bohr
162    let pos_bohr: Vec<[f64; 3]> = positions
163        .iter()
164        .map(|p| [p[0] * ANG_TO_BOHR, p[1] * ANG_TO_BOHR, p[2] * ANG_TO_BOHR])
165        .collect();
166
167    // Build basis set
168    let basis = build_sto3g_basis(elements, positions);
169    let n_basis = basis.n_basis();
170
171    // Count electrons
172    let n_electrons: usize = elements.iter().map(|&z| z as usize).sum();
173
174    // One-electron integrals
175    let s_mat = compute_overlap_matrix(&basis);
176    let t_mat = compute_kinetic_matrix(&basis);
177    let v_mat = compute_nuclear_matrix(&basis, elements, &pos_bohr);
178    let h_core = &t_mat + &v_mat;
179
180    // Two-electron integrals
181    let eris = compute_eris(&basis);
182
183    #[cfg(feature = "experimental-gpu")]
184    let gpu_eris_full = if n_basis >= 4 {
185        // Memory check: full N⁴ tensor requires n_basis⁴ × 8 bytes.
186        // Cap at ~512 MB to prevent OOM.
187        let n4 = (n_basis as u64)
188            .saturating_mul(n_basis as u64)
189            .saturating_mul(n_basis as u64)
190            .saturating_mul(n_basis as u64);
191        let mem_bytes = n4.saturating_mul(8);
192        let max_mem: u64 = 512 * 1024 * 1024; // 512 MB
193
194        if mem_bytes > max_mem {
195            None // Too large for dense tensor; fall back to CPU packed ERIs
196        } else if let Ok(ctx) = crate::gpu::context::GpuContext::try_create() {
197            let gpu_basis = hf_basis_to_gpu_basis(&basis);
198            crate::gpu::two_electron_gpu::compute_eris_gpu(&ctx, &gpu_basis)
199                .ok()
200                .map(|gpu_eris| {
201                    let cap = n_basis * n_basis * n_basis * n_basis;
202                    let mut full = Vec::with_capacity(cap);
203                    for mu in 0..n_basis {
204                        for nu in 0..n_basis {
205                            for lam in 0..n_basis {
206                                for sig in 0..n_basis {
207                                    full.push(gpu_eris.get(mu, nu, lam, sig));
208                                }
209                            }
210                        }
211                    }
212                    full
213                })
214        } else {
215            None
216        }
217    } else {
218        None
219    };
220
221    #[cfg(not(feature = "experimental-gpu"))]
222    let gpu_eris_full: Option<Vec<f64>> = None;
223
224    // SCF
225    let scf_config = ScfConfig {
226        max_iter: config.max_scf_iter,
227        diis_size: config.diis_size,
228        ..ScfConfig::default()
229    };
230    let scf_result = solve_scf(
231        &h_core,
232        &s_mat,
233        &eris,
234        gpu_eris_full.as_deref(),
235        n_electrons,
236        &scf_config,
237    );
238
239    // Nuclear repulsion
240    let e_nuc = nuclear_repulsion(elements, &pos_bohr);
241
242    // Empirical corrections
243    let (d3_e, gcp_e, srb_e) = if config.corrections {
244        (
245            compute_d3_energy(elements, &pos_bohr).energy,
246            compute_gcp(elements, &pos_bohr),
247            compute_srb(elements, &pos_bohr),
248        )
249    } else {
250        (0.0, 0.0, 0.0)
251    };
252
253    let total = scf_result.energy + e_nuc + d3_e + gcp_e + srb_e;
254
255    // CIS excited states
256    let cis = if config.n_cis_states > 0 && scf_result.converged {
257        let n_occ = n_electrons / 2;
258        let ao_map = super::basis::ao_to_atom_map(&basis);
259        Some(compute_cis_with_dipole(
260            &scf_result.orbital_energies,
261            &scf_result.coefficients,
262            &eris,
263            n_basis,
264            n_occ,
265            config.n_cis_states,
266            Some(&pos_bohr),
267            Some(&ao_map),
268        ))
269    } else {
270        None
271    };
272
273    // Extract HOMO/LUMO from orbital energies
274    let n_occ = n_electrons / 2;
275    let homo_energy = if n_occ > 0 && n_occ <= scf_result.orbital_energies.len() {
276        scf_result.orbital_energies[n_occ - 1]
277    } else {
278        0.0
279    };
280    let lumo_energy = if n_occ < scf_result.orbital_energies.len() {
281        Some(scf_result.orbital_energies[n_occ])
282    } else {
283        None
284    };
285    let gap = lumo_energy.map_or(0.0, |l| l - homo_energy);
286
287    // Mulliken charges from converged density
288    let mulliken_charges = if scf_result.converged {
289        let ps = &scf_result.density * &s_mat;
290        let ao_to_atom = super::basis::ao_to_atom_map(&basis);
291        let mut charges = vec![0.0_f64; elements.len()];
292        for mu in 0..n_basis {
293            charges[ao_to_atom[mu]] += ps[(mu, mu)];
294        }
295        charges
296            .iter()
297            .enumerate()
298            .map(|(i, &pop)| elements[i] as f64 - pop)
299            .collect()
300    } else {
301        vec![0.0; elements.len()]
302    };
303
304    Ok(Hf3cResult {
305        energy: total,
306        hf_energy: scf_result.energy + e_nuc,
307        nuclear_repulsion: e_nuc,
308        d3_energy: d3_e,
309        gcp_energy: gcp_e,
310        srb_energy: srb_e,
311        orbital_energies: scf_result.orbital_energies,
312        scf_iterations: scf_result.iterations,
313        converged: scf_result.converged,
314        cis,
315        n_basis,
316        n_electrons,
317        homo_energy,
318        lumo_energy,
319        gap,
320        mulliken_charges,
321    })
322}
323
324/// Batch-process multiple HF-3c calculations in parallel.
325#[cfg(feature = "parallel")]
326pub fn solve_hf3c_batch(
327    molecules: &[(&[u8], &[[f64; 3]])],
328    config: &HfConfig,
329) -> Vec<Result<Hf3cResult, String>> {
330    use rayon::prelude::*;
331    molecules
332        .par_iter()
333        .map(|(els, pos)| solve_hf3c(els, pos, config))
334        .collect()
335}
336
337/// Batch-process multiple HF-3c calculations sequentially.
338#[cfg(not(feature = "parallel"))]
339pub fn solve_hf3c_batch(
340    molecules: &[(&[u8], &[[f64; 3]])],
341    config: &HfConfig,
342) -> Vec<Result<Hf3cResult, String>> {
343    molecules
344        .iter()
345        .map(|(els, pos)| solve_hf3c(els, pos, config))
346        .collect()
347}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Water near its equilibrium geometry (Ångström).
    const WATER_ELEMENTS: [u8; 3] = [8, 1, 1];
    const WATER_POSITIONS: [[f64; 3]; 3] = [
        [0.0, 0.0, 0.117],
        [0.0, 0.757, -0.469],
        [0.0, -0.757, -0.469],
    ];

    /// Default config without CIS, for tests that only look at ground-state data.
    fn ground_state_config() -> HfConfig {
        HfConfig {
            n_cis_states: 0,
            ..Default::default()
        }
    }

    #[test]
    fn test_default_hf3c_iteration_budget() {
        assert_eq!(HfConfig::default().max_scf_iter, 300);
    }

    #[test]
    fn test_h2_hf3c() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]];
        let result = solve_hf3c(&elements, &positions, &ground_state_config()).unwrap();
        assert!(result.energy.is_finite(), "Energy should be finite");
        assert!(result.energy < 0.0, "H2 total energy should be negative");
    }

    #[test]
    fn test_water_hf3c() {
        let result =
            solve_hf3c(&WATER_ELEMENTS, &WATER_POSITIONS, &HfConfig::default()).unwrap();
        assert!(result.energy.is_finite());
        // STO-3G water: O(1s, 2s, 2p x3) + 2 x H(1s) = 7 basis functions.
        assert_eq!(result.n_basis, 7);
        assert_eq!(result.orbital_energies.len(), 7);
        assert_eq!(result.n_electrons, 10);
    }

    #[test]
    fn test_rejects_invalid_input() {
        let config = HfConfig::default();
        assert!(solve_hf3c(&[], &[], &config).is_err());
        assert!(solve_hf3c(&[1, 1], &[[0.0; 3]], &config).is_err());
    }

    #[test]
    fn test_energy_breakdown_is_consistent() {
        let r = solve_hf3c(&WATER_ELEMENTS, &WATER_POSITIONS, &ground_state_config()).unwrap();
        let recomposed = r.hf_energy + r.d3_energy + r.gcp_energy + r.srb_energy;
        assert!(
            (r.energy - recomposed).abs() < 1e-10,
            "energy {} != hf + d3 + gcp + srb = {}",
            r.energy,
            recomposed
        );
    }

    #[test]
    fn test_corrections_can_be_disabled() {
        let config = HfConfig {
            corrections: false,
            ..ground_state_config()
        };
        let r = solve_hf3c(&WATER_ELEMENTS, &WATER_POSITIONS, &config).unwrap();
        assert_eq!(r.d3_energy, 0.0);
        assert_eq!(r.gcp_energy, 0.0);
        assert_eq!(r.srb_energy, 0.0);
        assert_eq!(r.energy, r.hf_energy);
    }

    #[test]
    fn test_frontier_orbitals_match_orbital_energies() {
        let r = solve_hf3c(&WATER_ELEMENTS, &WATER_POSITIONS, &ground_state_config()).unwrap();
        // 10 electrons -> 5 doubly occupied orbitals: HOMO index 4, LUMO index 5.
        assert_eq!(r.homo_energy, r.orbital_energies[4]);
        let lumo = r
            .lumo_energy
            .expect("STO-3G water has virtual orbitals");
        assert_eq!(lumo, r.orbital_energies[5]);
        assert!((r.gap - (lumo - r.homo_energy)).abs() < 1e-12);
    }
}