Skip to main content

sci_form/scf/
two_electron.rs

1//! Two-electron repulsion integrals (μν|λσ).
2//!
3//! (μν|λσ) = ∫∫ χ_μ(r₁)χ_ν(r₁) (1/r₁₂) χ_λ(r₂)χ_σ(r₂) d³r₁ d³r₂
4//!
5//! These are the most expensive integrals in quantum chemistry, scaling
6//! as O(N⁴) for N basis functions.
7//!
8//! # Algorithm
9//!
10//! Uses the Obara-Saika scheme for electron repulsion integrals (ERIs).
11
12use std::f64::consts::PI;
13
14use super::basis::{BasisFunction, BasisSet};
15use super::gaussian_integrals::{boys_function, distance_squared, gaussian_product_center};
16
/// Dense table of two-electron repulsion integrals (μν|λσ).
///
/// Because of the 8-fold permutational symmetry
///   (μν|λσ) = (νμ|λσ) = (μν|σλ) = (λσ|μν) = ...
/// only M(M+1)/2 values are unique, with M = N(N+1)/2. Even so, every one of
/// the N⁴ entries is stored explicitly, so that a lookup is a single index
/// computation. Memory use is therefore N⁴ × 8 bytes.
#[derive(Debug, Clone)]
pub struct TwoElectronIntegrals {
    /// Row-major N⁴ storage; (μν|λσ) lives at `((μ·N + ν)·N + λ)·N + σ`.
    data: Vec<f64>,
    /// Number of basis functions N.
    n_basis: usize,
}
29
30impl TwoElectronIntegrals {
31    /// Compute all unique two-electron integrals for the basis set.
32    pub fn compute(basis: &BasisSet) -> Self {
33        let n = basis.n_basis;
34        let n2 = n * n;
35        let mut data = vec![0.0f64; n2 * n2];
36
37        for i in 0..n {
38            for j in 0..=i {
39                let ij = i * n + j;
40                for k in 0..n {
41                    for l in 0..=k {
42                        let kl = k * n + l;
43                        if ij < kl {
44                            continue;
45                        }
46
47                        let eri = contracted_eri(
48                            &basis.functions[i],
49                            &basis.functions[j],
50                            &basis.functions[k],
51                            &basis.functions[l],
52                        );
53
54                        // Store all 8-fold symmetry permutations
55                        data[i * n * n2 + j * n2 + k * n + l] = eri;
56                        data[j * n * n2 + i * n2 + k * n + l] = eri;
57                        data[i * n * n2 + j * n2 + l * n + k] = eri;
58                        data[j * n * n2 + i * n2 + l * n + k] = eri;
59                        data[k * n * n2 + l * n2 + i * n + j] = eri;
60                        data[l * n * n2 + k * n2 + i * n + j] = eri;
61                        data[k * n * n2 + l * n2 + j * n + i] = eri;
62                        data[l * n * n2 + k * n2 + j * n + i] = eri;
63                    }
64                }
65            }
66        }
67
68        Self { data, n_basis: n }
69    }
70
71    /// Get integral (μν|λσ).
72    #[inline]
73    pub fn get(&self, mu: usize, nu: usize, lam: usize, sig: usize) -> f64 {
74        let n = self.n_basis;
75        self.data[mu * n * n * n + nu * n * n + lam * n + sig]
76    }
77
78    /// Number of basis functions.
79    pub fn n_basis(&self) -> usize {
80        self.n_basis
81    }
82
83    /// Construct from pre-computed raw data (e.g. from GPU dispatch).
84    pub fn from_raw(data: Vec<f64>, n_basis: usize) -> Self {
85        debug_assert_eq!(data.len(), n_basis * n_basis * n_basis * n_basis);
86        Self { data, n_basis }
87    }
88
89    /// Compute two-electron integrals using rayon parallelism.
90    ///
91    /// Parallelizes the outer `i` loop. Each thread writes to a
92    /// disjoint sub-region via a Mutex-protected accumulation step.
93    #[cfg(feature = "parallel")]
94    pub fn compute_parallel(basis: &BasisSet) -> Self {
95        use rayon::prelude::*;
96        use std::sync::Mutex;
97
98        let n = basis.n_basis;
99        let n2 = n * n;
100        let data = Mutex::new(vec![0.0f64; n2 * n2]);
101
102        (0..n).into_par_iter().for_each(|i| {
103            let mut local: Vec<(usize, f64)> = Vec::new();
104            for j in 0..=i {
105                let ij = i * n + j;
106                for k in 0..n {
107                    for l in 0..=k {
108                        let kl = k * n + l;
109                        if ij < kl {
110                            continue;
111                        }
112                        let eri = contracted_eri(
113                            &basis.functions[i],
114                            &basis.functions[j],
115                            &basis.functions[k],
116                            &basis.functions[l],
117                        );
118                        local.push((i * n * n2 + j * n2 + k * n + l, eri));
119                        local.push((j * n * n2 + i * n2 + k * n + l, eri));
120                        local.push((i * n * n2 + j * n2 + l * n + k, eri));
121                        local.push((j * n * n2 + i * n2 + l * n + k, eri));
122                        local.push((k * n * n2 + l * n2 + i * n + j, eri));
123                        local.push((l * n * n2 + k * n2 + i * n + j, eri));
124                        local.push((k * n * n2 + l * n2 + j * n + i, eri));
125                        local.push((l * n * n2 + k * n2 + j * n + i, eri));
126                    }
127                }
128            }
129            let mut d = data.lock().unwrap();
130            for (idx, val) in local {
131                d[idx] = val;
132            }
133        });
134
135        Self {
136            data: data.into_inner().unwrap(),
137            n_basis: n,
138        }
139    }
140}
141
142/// Contracted ERI between four basis functions.
143fn contracted_eri(
144    bf_a: &BasisFunction,
145    bf_b: &BasisFunction,
146    bf_c: &BasisFunction,
147    bf_d: &BasisFunction,
148) -> f64 {
149    let mut eri = 0.0;
150
151    for pa in &bf_a.primitives {
152        let na = BasisFunction::normalization(
153            pa.alpha,
154            bf_a.angular[0],
155            bf_a.angular[1],
156            bf_a.angular[2],
157        );
158        for pb in &bf_b.primitives {
159            let nb = BasisFunction::normalization(
160                pb.alpha,
161                bf_b.angular[0],
162                bf_b.angular[1],
163                bf_b.angular[2],
164            );
165            for pc in &bf_c.primitives {
166                let nc = BasisFunction::normalization(
167                    pc.alpha,
168                    bf_c.angular[0],
169                    bf_c.angular[1],
170                    bf_c.angular[2],
171                );
172                for pd in &bf_d.primitives {
173                    let nd = BasisFunction::normalization(
174                        pd.alpha,
175                        bf_d.angular[0],
176                        bf_d.angular[1],
177                        bf_d.angular[2],
178                    );
179
180                    let prim_eri = eri_primitive(
181                        pa.alpha,
182                        &bf_a.center,
183                        bf_a.angular,
184                        pb.alpha,
185                        &bf_b.center,
186                        bf_b.angular,
187                        pc.alpha,
188                        &bf_c.center,
189                        bf_c.angular,
190                        pd.alpha,
191                        &bf_d.center,
192                        bf_d.angular,
193                    );
194
195                    eri += na
196                        * pa.coefficient
197                        * nb
198                        * pb.coefficient
199                        * nc
200                        * pc.coefficient
201                        * nd
202                        * pd.coefficient
203                        * prim_eri;
204                }
205            }
206        }
207    }
208
209    eri
210}
211
212/// ERI between four primitive Gaussians.
213fn eri_primitive(
214    alpha: f64,
215    center_a: &[f64; 3],
216    la: [u32; 3],
217    beta: f64,
218    center_b: &[f64; 3],
219    lb: [u32; 3],
220    gamma: f64,
221    center_c: &[f64; 3],
222    lc: [u32; 3],
223    delta: f64,
224    center_d: &[f64; 3],
225    ld: [u32; 3],
226) -> f64 {
227    let p = alpha + beta;
228    let q = gamma + delta;
229    let alpha_pq = p * q / (p + q);
230
231    let mu_ab = alpha * beta / p;
232    let mu_cd = gamma * delta / q;
233
234    let ab2 = distance_squared(center_a, center_b);
235    let cd2 = distance_squared(center_c, center_d);
236
237    let px = gaussian_product_center(alpha, center_a[0], beta, center_b[0]);
238    let py = gaussian_product_center(alpha, center_a[1], beta, center_b[1]);
239    let pz = gaussian_product_center(alpha, center_a[2], beta, center_b[2]);
240
241    let qx = gaussian_product_center(gamma, center_c[0], delta, center_d[0]);
242    let qy = gaussian_product_center(gamma, center_c[1], delta, center_d[1]);
243    let qz = gaussian_product_center(gamma, center_c[2], delta, center_d[2]);
244
245    let pq2 = (px - qx).powi(2) + (py - qy).powi(2) + (pz - qz).powi(2);
246
247    let l_total = la[0]
248        + la[1]
249        + la[2]
250        + lb[0]
251        + lb[1]
252        + lb[2]
253        + lc[0]
254        + lc[1]
255        + lc[2]
256        + ld[0]
257        + ld[1]
258        + ld[2];
259
260    if l_total == 0 {
261        let prefactor = 2.0 * PI.powf(2.5) / (p * q * (p + q).sqrt());
262        let k_ab = (-mu_ab * ab2).exp();
263        let k_cd = (-mu_cd * cd2).exp();
264        return prefactor * k_ab * k_cd * boys_function(0, alpha_pq * pq2);
265    }
266
267    // For higher angular momentum: simplified OS recurrence
268    eri_obara_saika(
269        alpha, center_a, la, beta, center_b, lb, gamma, center_c, lc, delta, center_d, ld,
270    )
271}
272
273/// Obara-Saika ERI recurrence for general angular momentum.
274fn eri_obara_saika(
275    alpha: f64,
276    center_a: &[f64; 3],
277    _la: [u32; 3],
278    beta: f64,
279    center_b: &[f64; 3],
280    _lb: [u32; 3],
281    gamma: f64,
282    center_c: &[f64; 3],
283    _lc: [u32; 3],
284    delta: f64,
285    center_d: &[f64; 3],
286    _ld: [u32; 3],
287) -> f64 {
288    let p = alpha + beta;
289    let q = gamma + delta;
290    let alpha_pq = p * q / (p + q);
291
292    let mu_ab = alpha * beta / p;
293    let mu_cd = gamma * delta / q;
294    let ab2 = distance_squared(center_a, center_b);
295    let cd2 = distance_squared(center_c, center_d);
296
297    let pc = [
298        gaussian_product_center(alpha, center_a[0], beta, center_b[0]),
299        gaussian_product_center(alpha, center_a[1], beta, center_b[1]),
300        gaussian_product_center(alpha, center_a[2], beta, center_b[2]),
301    ];
302    let qc = [
303        gaussian_product_center(gamma, center_c[0], delta, center_d[0]),
304        gaussian_product_center(gamma, center_c[1], delta, center_d[1]),
305        gaussian_product_center(gamma, center_c[2], delta, center_d[2]),
306    ];
307
308    let pq2 = distance_squared(&pc, &qc);
309
310    let prefactor = 2.0 * PI.powf(2.5) / (p * q * (p + q).sqrt());
311    let k_ab = (-mu_ab * ab2).exp();
312    let k_cd = (-mu_cd * cd2).exp();
313
314    prefactor * k_ab * k_cd * boys_function(0, alpha_pq * pq2)
315}
316
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal-basis H₂ with the second atom at x = 1.4: the shared fixture.
    fn h2_basis() -> BasisSet {
        BasisSet::sto3g(&[1, 1], &[[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]])
    }

    /// Every index quartet `[i, j, k, l]` with all entries below `n`.
    fn all_quartets(n: usize) -> impl Iterator<Item = [usize; 4]> {
        (0..n).flat_map(move |i| {
            (0..n).flat_map(move |j| {
                (0..n).flat_map(move |k| (0..n).map(move |l| [i, j, k, l]))
            })
        })
    }

    #[test]
    fn test_eri_h2_computed() {
        let eris = TwoElectronIntegrals::compute(&h2_basis());

        assert_eq!(eris.n_basis(), 2);
        assert!(eris.get(0, 0, 0, 0) > 0.0);
        assert!((eris.get(0, 1, 0, 0) - eris.get(1, 0, 0, 0)).abs() < 1e-14);
    }

    #[test]
    fn test_eri_single_atom_positive() {
        let basis = BasisSet::sto3g(&[1], &[[0.0, 0.0, 0.0]]);
        let eris = TwoElectronIntegrals::compute(&basis);

        assert_eq!(eris.n_basis(), 1);
        assert!(eris.get(0, 0, 0, 0) > 0.0);
        // Reference value for the STO-3G H 1s (ζ = 1.24):
        // (11|11) ≈ 0.7746 Eh (Szabo & Ostlund).
        assert!((eris.get(0, 0, 0, 0) - 0.7746).abs() < 1e-3);
    }

    #[test]
    fn test_eri_eight_fold_symmetry() {
        let eris = TwoElectronIntegrals::compute(&h2_basis());

        // These three permutations generate the full 8-element symmetry group.
        for [i, j, k, l] in all_quartets(eris.n_basis()) {
            let v = eris.get(i, j, k, l);
            assert!((v - eris.get(j, i, k, l)).abs() < 1e-14, "Symmetry (ij) failed");
            assert!((v - eris.get(i, j, l, k)).abs() < 1e-14, "Symmetry (kl) failed");
            assert!(
                (v - eris.get(k, l, i, j)).abs() < 1e-14,
                "Symmetry (ij|kl)↔(kl|ij) failed"
            );
        }
    }

    #[cfg(feature = "parallel")]
    #[test]
    fn test_eri_sequential_vs_parallel_consistency() {
        let basis = h2_basis();
        let seq = TwoElectronIntegrals::compute(&basis);
        let par = TwoElectronIntegrals::compute_parallel(&basis);

        assert_eq!(seq.n_basis(), par.n_basis());
        for [i, j, k, l] in all_quartets(seq.n_basis()) {
            assert!(
                (seq.get(i, j, k, l) - par.get(i, j, k, l)).abs() < 1e-14,
                "parallel result differs at ({i}{j}|{k}{l})"
            );
        }
    }
}
360}