Skip to main content

sci_form/hf/
integrals.rs

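//! Electron repulsion integral (ERI) evaluator.
//!
//! Implements the Obara-Saika recurrence (in its Head-Gordon–Pople vertical form)
//! for two-electron integrals over contracted Cartesian Gaussian shells. Only s and
//! p shells are handled, which suffices for HF-3c: the highest class needed is (pp|pp).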
5
6use super::basis::{BasisSet, Shell, ShellType};
7use super::nuclear::boys_function;
8
9pub fn compute_eris(basis: &BasisSet) -> Vec<f64> {
10    let n = basis.n_basis();
11    let size = eri_storage_size(n);
12    let mut eris = vec![0.0f64; size];
13
14    let shell_offsets = shell_function_offsets(basis);
15    let n_shells = basis.shells.len();
16
17    // Full loop over all shell quartets (no symmetry exploitation)
18    for a in 0..n_shells {
19        for b in 0..n_shells {
20            for c in 0..n_shells {
21                for d in 0..n_shells {
22                    compute_eri_quartet(
23                        &basis.shells[a],
24                        &basis.shells[b],
25                        &basis.shells[c],
26                        &basis.shells[d],
27                        shell_offsets[a],
28                        shell_offsets[b],
29                        shell_offsets[c],
30                        shell_offsets[d],
31                        n,
32                        &mut eris,
33                    );
34                }
35            }
36        }
37    }
38    eris
39}
40
41fn shell_function_offsets(basis: &BasisSet) -> Vec<usize> {
42    let mut offsets = Vec::with_capacity(basis.shells.len());
43    let mut offset = 0;
44    for shell in &basis.shells {
45        offsets.push(offset);
46        offset += shell.n_functions();
47    }
48    offsets
49}
50
/// Number of packed slots needed for `n` basis functions under 8-fold ERI symmetry:
/// the count of unordered pairs of unordered index pairs.
fn eri_storage_size(n: usize) -> usize {
    // Number of entries in a lower triangle (diagonal included) of size m.
    let triangle = |m: usize| m * (m + 1) / 2;
    triangle(triangle(n))
}
55
/// Packed storage position of `(ij|kl)`, invariant under the 8-fold permutational
/// symmetry `i<->j`, `k<->l`, `(ij)<->(kl)`. `_n` is accepted for API symmetry but unused.
pub fn eri_index(i: usize, j: usize, k: usize, l: usize, _n: usize) -> usize {
    // Lower-triangle position of the unordered pair {x, y}.
    let packed = |x: usize, y: usize| {
        let (hi, lo) = if x < y { (y, x) } else { (x, y) };
        hi * (hi + 1) / 2 + lo
    };
    // The same packing applied to the two pair indices folds the bra/ket swap.
    packed(packed(i, j), packed(k, l))
}
73
74pub fn get_eri(eris: &[f64], i: usize, j: usize, k: usize, l: usize, n: usize) -> f64 {
75    eris[eri_index(i, j, k, l, n)]
76}
77
78fn compute_eri_quartet(
79    sa: &Shell,
80    sb: &Shell,
81    sc: &Shell,
82    sd: &Shell,
83    off_a: usize,
84    off_b: usize,
85    off_c: usize,
86    off_d: usize,
87    n_basis: usize,
88    eris: &mut [f64],
89) {
90    let la = if sa.shell_type == ShellType::P { 1 } else { 0 };
91    let lb = if sb.shell_type == ShellType::P { 1 } else { 0 };
92    let lc = if sc.shell_type == ShellType::P { 1 } else { 0 };
93    let ld = if sd.shell_type == ShellType::P { 1 } else { 0 };
94
95    // Temporary accumulator for primitives within this shell quartet
96    let na = sa.n_functions();
97    let nb = sb.n_functions();
98    let nc = sc.n_functions();
99    let nd = sd.n_functions();
100    let mut temp = vec![0.0f64; na * nb * nc * nd];
101
102    for (&ea, &ca) in sa.exponents.iter().zip(&sa.coefficients) {
103        for (&eb, &cb) in sb.exponents.iter().zip(&sb.coefficients) {
104            let zeta = ea + eb;
105            let ab2 = dist_sq(&sa.center, &sb.center);
106            let kab = (-ea * eb / zeta * ab2).exp();
107            let p = gaussian_product(&sa.center, ea, &sb.center, eb);
108
109            for (&ec, &cc) in sc.exponents.iter().zip(&sc.coefficients) {
110                for (&ed, &cd) in sd.exponents.iter().zip(&sd.coefficients) {
111                    let eta = ec + ed;
112                    let cd2 = dist_sq(&sc.center, &sd.center);
113                    let kcd = (-ec * ed / eta * cd2).exp();
114                    let q = gaussian_product(&sc.center, ec, &sd.center, ed);
115
116                    let rho = zeta * eta / (zeta + eta);
117                    let pq2 = dist_sq(&p, &q);
118                    let t = rho * pq2;
119
120                    let prefactor = 2.0 * std::f64::consts::PI.powi(2) / (zeta * eta)
121                        * (std::f64::consts::PI / (zeta + eta)).sqrt()
122                        * kab
123                        * kcd
124                        * ca
125                        * cb
126                        * cc
127                        * cd;
128
129                    let w = gaussian_product(&p, zeta, &q, eta);
130
131                    let prim_eris = os_eri_primitives(
132                        &sa.center, &sb.center, &sc.center, &sd.center, &p, &q, &w, zeta, eta, rho,
133                        t, prefactor, la, lb, lc, ld,
134                    );
135
136                    for fi in 0..na {
137                        for fj in 0..nb {
138                            for fk in 0..nc {
139                                for fl in 0..nd {
140                                    temp[((fi * nb + fj) * nc + fk) * nd + fl] +=
141                                        prim_eris[fi][fj][fk][fl];
142                                }
143                            }
144                        }
145                    }
146                }
147            }
148        }
149    }
150
151    // Write accumulated values to storage (direct assign — no cross-quartet accumulation)
152    for fi in 0..na {
153        for fj in 0..nb {
154            for fk in 0..nc {
155                for fl in 0..nd {
156                    let i = off_a + fi;
157                    let j = off_b + fj;
158                    let k = off_c + fk;
159                    let l = off_d + fl;
160                    let val = temp[((fi * nb + fj) * nc + fk) * nd + fl];
161                    let idx = eri_index(i, j, k, l, n_basis);
162                    eris[idx] = val;
163                }
164            }
165        }
166    }
167}
168
169// Computes all primitive ERI combinations for (s,p) orbitals.
170// Returns a 4D array sized [3][3][3][3]. Indices > 0 are only valid if L_x = 1.
171// Index 0 represents s, indices 1,2,3 represent px,py,pz (shifted to 0,1,2).
172fn os_eri_primitives(
173    a: &[f64; 3],
174    b: &[f64; 3],
175    c: &[f64; 3],
176    d: &[f64; 3],
177    p: &[f64; 3],
178    q: &[f64; 3],
179    w: &[f64; 3],
180    zeta: f64,
181    eta: f64,
182    _rho: f64,
183    t: f64,
184    prefactor: f64,
185    la: usize,
186    lb: usize,
187    lc: usize,
188    ld: usize,
189) -> [[[[f64; 3]; 3]; 3]; 3] {
190    let mut f = [0.0; 5];
191    for m in 0..=(la + lb + lc + ld) {
192        f[m] = boys_function(m, t) * prefactor;
193    }
194
195    // Helper to evaluate OS tree.
196    // ss_ss[m]
197    let ss_ss = |m: usize| f[m];
198
199    // ps_ss[i][m]
200    let mut ps_ss = [[0.0; 5]; 3];
201    if la > 0 || lb > 0 || lc > 0 || ld > 0 {
202        for i in 0..3 {
203            for m in 0..=3 {
204                ps_ss[i][m] = (p[i] - a[i]) * ss_ss(m) + (w[i] - p[i]) * ss_ss(m + 1);
205            }
206        }
207    }
208
209    // pp_ss[i][j][m]
210    let mut pp_ss = [[[0.0; 5]; 3]; 3];
211    if lb > 0 || lc > 0 || ld > 0 {
212        for i in 0..3 {
213            for j in 0..3 {
214                for m in 0..=2 {
215                    pp_ss[i][j][m] = (p[j] - b[j]) * ps_ss[i][m] + (w[j] - p[j]) * ps_ss[i][m + 1];
216                    if i == j {
217                        pp_ss[i][j][m] +=
218                            1.0 / (2.0 * zeta) * (ss_ss(m) - eta / (zeta + eta) * ss_ss(m + 1));
219                    }
220                }
221            }
222        }
223    }
224
225    // ps_ps[i][k][m]
226    let mut ps_ps = [[[0.0; 5]; 3]; 3];
227    if lc > 0 || ld > 0 {
228        for i in 0..3 {
229            for k in 0..3 {
230                for m in 0..=2 {
231                    ps_ps[i][k][m] = (q[k] - c[k]) * ps_ss[i][m] + (w[k] - q[k]) * ps_ss[i][m + 1];
232                    if i == k {
233                        ps_ps[i][k][m] += 1.0 / (2.0 * (zeta + eta)) * ss_ss(m + 1);
234                    }
235                }
236            }
237        }
238    }
239
240    // pp_ps[i][j][k][m]
241    let mut pp_ps = [[[[0.0; 5]; 3]; 3]; 3];
242    if (la > 0 && lb > 0 && lc > 0) || ld > 0 {
243        for i in 0..3 {
244            for j in 0..3 {
245                for k in 0..3 {
246                    for m in 0..=1 {
247                        pp_ps[i][j][k][m] =
248                            (q[k] - c[k]) * pp_ss[i][j][m] + (w[k] - q[k]) * pp_ss[i][j][m + 1];
249                        if i == k {
250                            pp_ps[i][j][k][m] += 1.0 / (2.0 * (zeta + eta)) * ps_ss[j][m + 1];
251                            // ps_ss is (0,b|0,0) which is same as (a,0|0,0) by symmetry if we swap
252                        }
253                        if j == k {
254                            pp_ps[i][j][k][m] += 1.0 / (2.0 * (zeta + eta)) * ps_ss[i][m + 1];
255                        }
256                    }
257                }
258            }
259        }
260    }
261
262    // pp_pp[i][j][k][l][m]
263    let mut pp_pp = [[[[[0.0; 5]; 3]; 3]; 3]; 3];
264    if la > 0 && lb > 0 && lc > 0 && ld > 0 {
265        for i in 0..3 {
266            for j in 0..3 {
267                for k in 0..3 {
268                    for l in 0..3 {
269                        for m in 0..=0 {
270                            pp_pp[i][j][k][l][m] = (q[l] - d[l]) * pp_ps[i][j][k][m]
271                                + (w[l] - q[l]) * pp_ps[i][j][k][m + 1];
272                            if i == l {
273                                pp_pp[i][j][k][l][m] +=
274                                    1.0 / (2.0 * (zeta + eta)) * pp_ps[j][0][k][m + 1];
275                                // Wait, pp_ps[j][0][k] doesn't mean (0,b|c,0).
276                                // Actually, we need (0,b|c,0). By symmetry, it is ps_ps[j][k].
277                            }
278                            if j == l {
279                                pp_pp[i][j][k][l][m] +=
280                                    1.0 / (2.0 * (zeta + eta)) * ps_ps[i][k][m + 1];
281                            }
282                            if k == l {
283                                pp_pp[i][j][k][l][m] += 1.0 / (2.0 * eta)
284                                    * (pp_ss[i][j][m] - zeta / (zeta + eta) * pp_ss[i][j][m + 1]);
285                            }
286                        }
287                    }
288                }
289            }
290        }
291    }
292
293    let mut res = [[[[0.0; 3]; 3]; 3]; 3];
294    let max_i = if la > 0 { 3 } else { 1 };
295    let max_j = if lb > 0 { 3 } else { 1 };
296    let max_k = if lc > 0 { 3 } else { 1 };
297    let max_l = if ld > 0 { 3 } else { 1 };
298
299    for i in 0..max_i {
300        for j in 0..max_j {
301            for k in 0..max_k {
302                for l in 0..max_l {
303                    if la == 0 && lb == 0 && lc == 0 && ld == 0 {
304                        res[i][j][k][l] = ss_ss(0);
305                    } else if la > 0 && lb == 0 && lc == 0 && ld == 0 {
306                        res[i][j][k][l] = ps_ss[i][0];
307                    } else if la == 0 && lb > 0 && lc == 0 && ld == 0 {
308                        // (s, p | s, s) is same as (p, s | s, s) evaluated with B and A swapped.
309                        res[i][j][k][l] = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
310                    } else if la > 0 && lb > 0 && lc == 0 && ld == 0 {
311                        res[i][j][k][l] = pp_ss[i][j][0];
312                    } else if la == 0 && lb == 0 && lc == 0 && ld > 0 {
313                        // (s, s | s, p)
314                        res[i][j][k][l] = (q[l] - d[l]) * ss_ss(0) + (w[l] - q[l]) * ss_ss(1);
315                    } else if la > 0 && lb == 0 && lc > 0 && ld == 0 {
316                        res[i][j][k][l] = ps_ps[i][k][0];
317                    } else if la == 0 && lb == 0 && lc > 0 && ld == 0 {
318                        res[i][j][k][l] = (q[k] - c[k]) * ss_ss(0) + (w[k] - q[k]) * ss_ss(1);
319                    } else if la == 0 && lb == 0 && lc > 0 && ld > 0 {
320                        // (s, s | p, p) is same as (p, p | s, s) exchanging AB with CD.
321                        let mut pp_ss_cd = (q[l] - d[l])
322                            * ((q[k] - c[k]) * ss_ss(0) + (w[k] - q[k]) * ss_ss(1))
323                            + (w[l] - q[l]) * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
324                        if k == l {
325                            pp_ss_cd +=
326                                1.0 / (2.0 * eta) * (ss_ss(0) - zeta / (zeta + eta) * ss_ss(1));
327                        }
328                        res[i][j][k][l] = pp_ss_cd;
329                    } else if la == 0 && lb > 0 && lc > 0 && ld == 0 {
330                        // (s, p | p, s)
331                        let ps_ss_b = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
332                        let ps_ss_b_1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
333                        let mut val = (q[k] - c[k]) * ps_ss_b + (w[k] - q[k]) * ps_ss_b_1;
334                        if j == k {
335                            val += 1.0 / (2.0 * (zeta + eta)) * ss_ss(1);
336                        }
337                        res[i][j][k][l] = val;
338                    } else if la > 0 && lb > 0 && lc > 0 && ld == 0 {
339                        res[i][j][k][l] = pp_ps[i][j][k][0];
340                    } else if la > 0 && lb > 0 && lc == 0 && ld > 0 {
341                        // (p, p | s, p)
342                        // symmetric to pp_ps where c -> d
343                        let mut pp_ps_d =
344                            (q[l] - d[l]) * pp_ss[i][j][0] + (w[l] - q[l]) * pp_ss[i][j][1];
345                        if i == l {
346                            pp_ps_d += 1.0 / (2.0 * (zeta + eta)) * ps_ss[j][1];
347                        }
348                        if j == l {
349                            pp_ps_d += 1.0 / (2.0 * (zeta + eta)) * ps_ss[i][1];
350                        }
351                        res[i][j][k][l] = pp_ps_d;
352                    } else if la > 0 && lb == 0 && lc > 0 && ld > 0 {
353                        // (p, s | p, p)
354                        let mut ps_pp =
355                            (q[l] - d[l]) * ps_ps[i][k][0] + (w[l] - q[l]) * ps_ps[i][k][1];
356                        if i == l {
357                            ps_pp += 1.0 / (2.0 * (zeta + eta))
358                                * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
359                        }
360                        if k == l {
361                            ps_pp += 1.0 / (2.0 * eta)
362                                * (ps_ss[i][0] - zeta / (zeta + eta) * ps_ss[i][1]);
363                        }
364                        res[i][j][k][l] = ps_pp;
365                    } else if la == 0 && lb > 0 && lc > 0 && ld > 0 {
366                        // (s, p | p, p)
367                        let ps_ss_b0 = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
368                        let ps_ss_b1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
369                        let ps_ss_b2 = (p[j] - b[j]) * ss_ss(2) + (w[j] - p[j]) * ss_ss(3);
370                        let ps_ps_bk0 = (q[k] - c[k]) * ps_ss_b0
371                            + (w[k] - q[k]) * ps_ss_b1
372                            + if j == k {
373                                1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
374                            } else {
375                                0.0
376                            };
377                        let ps_ps_bk1 = (q[k] - c[k]) * ps_ss_b1
378                            + (w[k] - q[k]) * ps_ss_b2
379                            + if j == k {
380                                1.0 / (2.0 * (zeta + eta)) * ss_ss(2)
381                            } else {
382                                0.0
383                            };
384                        let mut sp_pp = (q[l] - d[l]) * ps_ps_bk0 + (w[l] - q[l]) * ps_ps_bk1;
385                        if j == l {
386                            sp_pp += 1.0 / (2.0 * (zeta + eta))
387                                * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
388                        }
389                        if k == l {
390                            sp_pp +=
391                                1.0 / (2.0 * eta) * (ps_ss_b0 - zeta / (zeta + eta) * ps_ss_b1);
392                        }
393                        res[i][j][k][l] = sp_pp;
394                    } else if la > 0 && lb > 0 && lc > 0 && ld > 0 {
395                        // Fully populated pp_pp, let's substitute the symmetry correct components
396                        let mut term =
397                            (q[l] - d[l]) * pp_ps[i][j][k][0] + (w[l] - q[l]) * pp_ps[i][j][k][1];
398                        if i == l {
399                            term += 1.0 / (2.0 * (zeta + eta)) * ps_ps[j][k][1];
400                        }
401                        if j == l {
402                            term += 1.0 / (2.0 * (zeta + eta)) * ps_ps[i][k][1];
403                        }
404                        if k == l {
405                            term += 1.0 / (2.0 * eta)
406                                * (pp_ss[i][j][0] - zeta / (zeta + eta) * pp_ss[i][j][1]);
407                        }
408                        res[i][j][k][l] = term;
409                    } else if la > 0 && lb == 0 && lc == 0 && ld > 0 {
410                        // (p, s | s, p)
411                        let val = (q[l] - d[l]) * ps_ss[i][0]
412                            + (w[l] - q[l]) * ps_ss[i][1]
413                            + if i == l {
414                                1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
415                            } else {
416                                0.0
417                            };
418                        res[i][j][k][l] = val;
419                    } else if la == 0 && lb > 0 && lc == 0 && ld > 0 {
420                        // (s, p | s, p)
421                        let ps_ss_b = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
422                        let ps_ss_b1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
423                        res[i][j][k][l] = (q[l] - d[l]) * ps_ss_b
424                            + (w[l] - q[l]) * ps_ss_b1
425                            + if j == l {
426                                1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
427                            } else {
428                                0.0
429                            };
430                    }
431                }
432            }
433        }
434    }
435    res
436}
437
/// Center of the product of two Gaussians with exponents `ea` at `a` and `eb` at `b`:
/// the exponent-weighted average `(ea * a + eb * b) / (ea + eb)`.
fn gaussian_product(a: &[f64; 3], ea: f64, b: &[f64; 3], eb: f64) -> [f64; 3] {
    let total = ea + eb;
    [0usize, 1, 2].map(|axis| (ea * a[axis] + eb * b[axis]) / total)
}
446
/// Squared Euclidean distance between two points.
#[inline]
fn dist_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b)
        .map(|(x, y)| {
            let delta = x - y;
            delta * delta
        })
        .sum()
}
454
#[cfg(test)]
mod tests {
    use super::super::basis::build_sto3g_basis;
    use super::*;

    /// Builds the STO-3G basis for two hydrogens 0.74 apart along z and returns
    /// its packed ERIs together with the number of basis functions.
    fn h2_eris() -> (Vec<f64>, usize) {
        let basis = build_sto3g_basis(&[1, 1], &[[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]]);
        let n = basis.n_basis();
        (compute_eris(&basis), n)
    }

    #[test]
    fn test_eri_h2() {
        let (eris, n) = h2_eris();
        // (11|11) should be positive (electron-electron repulsion)
        let eri_0000 = get_eri(&eris, 0, 0, 0, 0, n);
        assert!(eri_0000 > 0.0, "(11|11) = {eri_0000}");
    }

    #[test]
    fn test_eri_symmetry() {
        let (eris, n) = h2_eris();
        // Permutational symmetry
        assert_eq!(get_eri(&eris, 0, 1, 0, 1, n), get_eri(&eris, 1, 0, 1, 0, n));
        assert_eq!(get_eri(&eris, 0, 1, 0, 1, n), get_eri(&eris, 0, 1, 1, 0, n));
    }

    /// Every canonical quartet (i >= j, k >= l, ij >= kl) must map to its own packed
    /// slot, the slots must exactly fill `eri_storage_size(n)`, and all eight index
    /// permutations of a quartet must share that slot.
    #[test]
    fn test_eri_index_packs_canonical_quartets_densely() {
        for n in 1..=4 {
            let size = eri_storage_size(n);
            let mut seen = vec![false; size];
            for i in 0..n {
                for j in 0..=i {
                    for k in 0..n {
                        for l in 0..=k {
                            if k * (k + 1) / 2 + l > i * (i + 1) / 2 + j {
                                continue;
                            }
                            let idx = eri_index(i, j, k, l, n);
                            assert!(idx < size, "({i}{j}|{k}{l}) -> {idx} >= {size}");
                            assert!(!seen[idx], "slot {idx} reused by ({i}{j}|{k}{l})");
                            seen[idx] = true;
                            for (a, b, c, d) in [
                                (j, i, k, l),
                                (i, j, l, k),
                                (j, i, l, k),
                                (k, l, i, j),
                                (l, k, i, j),
                                (k, l, j, i),
                                (l, k, j, i),
                            ] {
                                assert_eq!(eri_index(a, b, c, d, n), idx);
                            }
                        }
                    }
                }
            }
            assert!(seen.iter().all(|&used| used), "unused packed slots for n = {n}");
        }
    }
}
475        // Permutational symmetry
476        assert_eq!(get_eri(&eris, 0, 1, 0, 1, n), get_eri(&eris, 1, 0, 1, 0, n));
477        assert_eq!(get_eri(&eris, 0, 1, 0, 1, n), get_eri(&eris, 0, 1, 1, 0, n));
478    }
479}