1use super::basis::{BasisSet, Shell, ShellType};
7use super::nuclear::boys_function;
8
9pub fn compute_eris(basis: &BasisSet) -> Vec<f64> {
10 let n = basis.n_basis();
11 let size = eri_storage_size(n);
12 let mut eris = vec![0.0f64; size];
13
14 let shell_offsets = shell_function_offsets(basis);
15 let n_shells = basis.shells.len();
16
17 for a in 0..n_shells {
19 for b in 0..n_shells {
20 for c in 0..n_shells {
21 for d in 0..n_shells {
22 compute_eri_quartet(
23 &basis.shells[a],
24 &basis.shells[b],
25 &basis.shells[c],
26 &basis.shells[d],
27 shell_offsets[a],
28 shell_offsets[b],
29 shell_offsets[c],
30 shell_offsets[d],
31 n,
32 &mut eris,
33 );
34 }
35 }
36 }
37 }
38 eris
39}
40
41fn shell_function_offsets(basis: &BasisSet) -> Vec<usize> {
42 let mut offsets = Vec::with_capacity(basis.shells.len());
43 let mut offset = 0;
44 for shell in &basis.shells {
45 offsets.push(offset);
46 offset += shell.n_functions();
47 }
48 offsets
49}
50
/// Number of slots needed to store ERIs for `n` basis functions with 8-fold
/// permutational symmetry: the triangular number of the triangular number of `n`.
fn eri_storage_size(n: usize) -> usize {
    let triangular = |x: usize| x * (x + 1) / 2;
    triangular(triangular(n))
}
55
/// Packed storage position of the integral `(ij|kl)`.
///
/// Each index pair is folded into a lower-triangular compound index
/// `max * (max + 1) / 2 + min`, and the two compound indices are folded the same
/// way. All permutations of `(ij|kl)` obtained by swapping `i`/`j`, `k`/`l`, or the
/// two pairs therefore map to the same slot. `_n` is accepted only for interface
/// compatibility and is ignored.
pub fn eri_index(i: usize, j: usize, k: usize, l: usize, _n: usize) -> usize {
    fn fold(x: usize, y: usize) -> usize {
        let (hi, lo) = if x < y { (y, x) } else { (x, y) };
        hi * (hi + 1) / 2 + lo
    }
    fold(fold(i, j), fold(k, l))
}
73
74pub fn get_eri(eris: &[f64], i: usize, j: usize, k: usize, l: usize, n: usize) -> f64 {
75 eris[eri_index(i, j, k, l, n)]
76}
77
78fn compute_eri_quartet(
79 sa: &Shell,
80 sb: &Shell,
81 sc: &Shell,
82 sd: &Shell,
83 off_a: usize,
84 off_b: usize,
85 off_c: usize,
86 off_d: usize,
87 n_basis: usize,
88 eris: &mut [f64],
89) {
90 let la = if sa.shell_type == ShellType::P { 1 } else { 0 };
91 let lb = if sb.shell_type == ShellType::P { 1 } else { 0 };
92 let lc = if sc.shell_type == ShellType::P { 1 } else { 0 };
93 let ld = if sd.shell_type == ShellType::P { 1 } else { 0 };
94
95 let na = sa.n_functions();
97 let nb = sb.n_functions();
98 let nc = sc.n_functions();
99 let nd = sd.n_functions();
100 let mut temp = vec![0.0f64; na * nb * nc * nd];
101
102 for (&ea, &ca) in sa.exponents.iter().zip(&sa.coefficients) {
103 for (&eb, &cb) in sb.exponents.iter().zip(&sb.coefficients) {
104 let zeta = ea + eb;
105 let ab2 = dist_sq(&sa.center, &sb.center);
106 let kab = (-ea * eb / zeta * ab2).exp();
107 let p = gaussian_product(&sa.center, ea, &sb.center, eb);
108
109 for (&ec, &cc) in sc.exponents.iter().zip(&sc.coefficients) {
110 for (&ed, &cd) in sd.exponents.iter().zip(&sd.coefficients) {
111 let eta = ec + ed;
112 let cd2 = dist_sq(&sc.center, &sd.center);
113 let kcd = (-ec * ed / eta * cd2).exp();
114 let q = gaussian_product(&sc.center, ec, &sd.center, ed);
115
116 let rho = zeta * eta / (zeta + eta);
117 let pq2 = dist_sq(&p, &q);
118 let t = rho * pq2;
119
120 let prefactor = 2.0 * std::f64::consts::PI.powi(2) / (zeta * eta)
121 * (std::f64::consts::PI / (zeta + eta)).sqrt()
122 * kab
123 * kcd
124 * ca
125 * cb
126 * cc
127 * cd;
128
129 let w = gaussian_product(&p, zeta, &q, eta);
130
131 let prim_eris = os_eri_primitives(
132 &sa.center, &sb.center, &sc.center, &sd.center, &p, &q, &w, zeta, eta, rho,
133 t, prefactor, la, lb, lc, ld,
134 );
135
136 for fi in 0..na {
137 for fj in 0..nb {
138 for fk in 0..nc {
139 for fl in 0..nd {
140 temp[((fi * nb + fj) * nc + fk) * nd + fl] +=
141 prim_eris[fi][fj][fk][fl];
142 }
143 }
144 }
145 }
146 }
147 }
148 }
149 }
150
151 for fi in 0..na {
153 for fj in 0..nb {
154 for fk in 0..nc {
155 for fl in 0..nd {
156 let i = off_a + fi;
157 let j = off_b + fj;
158 let k = off_c + fk;
159 let l = off_d + fl;
160 let val = temp[((fi * nb + fj) * nc + fk) * nd + fl];
161 let idx = eri_index(i, j, k, l, n_basis);
162 eris[idx] = val;
163 }
164 }
165 }
166 }
167}
168
169fn os_eri_primitives(
173 a: &[f64; 3],
174 b: &[f64; 3],
175 c: &[f64; 3],
176 d: &[f64; 3],
177 p: &[f64; 3],
178 q: &[f64; 3],
179 w: &[f64; 3],
180 zeta: f64,
181 eta: f64,
182 _rho: f64,
183 t: f64,
184 prefactor: f64,
185 la: usize,
186 lb: usize,
187 lc: usize,
188 ld: usize,
189) -> [[[[f64; 3]; 3]; 3]; 3] {
190 let mut f = [0.0; 5];
191 for m in 0..=(la + lb + lc + ld) {
192 f[m] = boys_function(m, t) * prefactor;
193 }
194
195 let ss_ss = |m: usize| f[m];
198
199 let mut ps_ss = [[0.0; 5]; 3];
201 if la > 0 || lb > 0 || lc > 0 || ld > 0 {
202 for i in 0..3 {
203 for m in 0..=3 {
204 ps_ss[i][m] = (p[i] - a[i]) * ss_ss(m) + (w[i] - p[i]) * ss_ss(m + 1);
205 }
206 }
207 }
208
209 let mut pp_ss = [[[0.0; 5]; 3]; 3];
211 if lb > 0 || lc > 0 || ld > 0 {
212 for i in 0..3 {
213 for j in 0..3 {
214 for m in 0..=2 {
215 pp_ss[i][j][m] = (p[j] - b[j]) * ps_ss[i][m] + (w[j] - p[j]) * ps_ss[i][m + 1];
216 if i == j {
217 pp_ss[i][j][m] +=
218 1.0 / (2.0 * zeta) * (ss_ss(m) - eta / (zeta + eta) * ss_ss(m + 1));
219 }
220 }
221 }
222 }
223 }
224
225 let mut ps_ps = [[[0.0; 5]; 3]; 3];
227 if lc > 0 || ld > 0 {
228 for i in 0..3 {
229 for k in 0..3 {
230 for m in 0..=2 {
231 ps_ps[i][k][m] = (q[k] - c[k]) * ps_ss[i][m] + (w[k] - q[k]) * ps_ss[i][m + 1];
232 if i == k {
233 ps_ps[i][k][m] += 1.0 / (2.0 * (zeta + eta)) * ss_ss(m + 1);
234 }
235 }
236 }
237 }
238 }
239
240 let mut pp_ps = [[[[0.0; 5]; 3]; 3]; 3];
242 if (la > 0 && lb > 0 && lc > 0) || ld > 0 {
243 for i in 0..3 {
244 for j in 0..3 {
245 for k in 0..3 {
246 for m in 0..=1 {
247 pp_ps[i][j][k][m] =
248 (q[k] - c[k]) * pp_ss[i][j][m] + (w[k] - q[k]) * pp_ss[i][j][m + 1];
249 if i == k {
250 pp_ps[i][j][k][m] += 1.0 / (2.0 * (zeta + eta)) * ps_ss[j][m + 1];
251 }
253 if j == k {
254 pp_ps[i][j][k][m] += 1.0 / (2.0 * (zeta + eta)) * ps_ss[i][m + 1];
255 }
256 }
257 }
258 }
259 }
260 }
261
262 let mut pp_pp = [[[[[0.0; 5]; 3]; 3]; 3]; 3];
264 if la > 0 && lb > 0 && lc > 0 && ld > 0 {
265 for i in 0..3 {
266 for j in 0..3 {
267 for k in 0..3 {
268 for l in 0..3 {
269 for m in 0..=0 {
270 pp_pp[i][j][k][l][m] = (q[l] - d[l]) * pp_ps[i][j][k][m]
271 + (w[l] - q[l]) * pp_ps[i][j][k][m + 1];
272 if i == l {
273 pp_pp[i][j][k][l][m] +=
274 1.0 / (2.0 * (zeta + eta)) * pp_ps[j][0][k][m + 1];
275 }
278 if j == l {
279 pp_pp[i][j][k][l][m] +=
280 1.0 / (2.0 * (zeta + eta)) * ps_ps[i][k][m + 1];
281 }
282 if k == l {
283 pp_pp[i][j][k][l][m] += 1.0 / (2.0 * eta)
284 * (pp_ss[i][j][m] - zeta / (zeta + eta) * pp_ss[i][j][m + 1]);
285 }
286 }
287 }
288 }
289 }
290 }
291 }
292
293 let mut res = [[[[0.0; 3]; 3]; 3]; 3];
294 let max_i = if la > 0 { 3 } else { 1 };
295 let max_j = if lb > 0 { 3 } else { 1 };
296 let max_k = if lc > 0 { 3 } else { 1 };
297 let max_l = if ld > 0 { 3 } else { 1 };
298
299 for i in 0..max_i {
300 for j in 0..max_j {
301 for k in 0..max_k {
302 for l in 0..max_l {
303 if la == 0 && lb == 0 && lc == 0 && ld == 0 {
304 res[i][j][k][l] = ss_ss(0);
305 } else if la > 0 && lb == 0 && lc == 0 && ld == 0 {
306 res[i][j][k][l] = ps_ss[i][0];
307 } else if la == 0 && lb > 0 && lc == 0 && ld == 0 {
308 res[i][j][k][l] = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
310 } else if la > 0 && lb > 0 && lc == 0 && ld == 0 {
311 res[i][j][k][l] = pp_ss[i][j][0];
312 } else if la == 0 && lb == 0 && lc == 0 && ld > 0 {
313 res[i][j][k][l] = (q[l] - d[l]) * ss_ss(0) + (w[l] - q[l]) * ss_ss(1);
315 } else if la > 0 && lb == 0 && lc > 0 && ld == 0 {
316 res[i][j][k][l] = ps_ps[i][k][0];
317 } else if la == 0 && lb == 0 && lc > 0 && ld == 0 {
318 res[i][j][k][l] = (q[k] - c[k]) * ss_ss(0) + (w[k] - q[k]) * ss_ss(1);
319 } else if la == 0 && lb == 0 && lc > 0 && ld > 0 {
320 let mut pp_ss_cd = (q[l] - d[l])
322 * ((q[k] - c[k]) * ss_ss(0) + (w[k] - q[k]) * ss_ss(1))
323 + (w[l] - q[l]) * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
324 if k == l {
325 pp_ss_cd +=
326 1.0 / (2.0 * eta) * (ss_ss(0) - zeta / (zeta + eta) * ss_ss(1));
327 }
328 res[i][j][k][l] = pp_ss_cd;
329 } else if la == 0 && lb > 0 && lc > 0 && ld == 0 {
330 let ps_ss_b = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
332 let ps_ss_b_1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
333 let mut val = (q[k] - c[k]) * ps_ss_b + (w[k] - q[k]) * ps_ss_b_1;
334 if j == k {
335 val += 1.0 / (2.0 * (zeta + eta)) * ss_ss(1);
336 }
337 res[i][j][k][l] = val;
338 } else if la > 0 && lb > 0 && lc > 0 && ld == 0 {
339 res[i][j][k][l] = pp_ps[i][j][k][0];
340 } else if la > 0 && lb > 0 && lc == 0 && ld > 0 {
341 let mut pp_ps_d =
344 (q[l] - d[l]) * pp_ss[i][j][0] + (w[l] - q[l]) * pp_ss[i][j][1];
345 if i == l {
346 pp_ps_d += 1.0 / (2.0 * (zeta + eta)) * ps_ss[j][1];
347 }
348 if j == l {
349 pp_ps_d += 1.0 / (2.0 * (zeta + eta)) * ps_ss[i][1];
350 }
351 res[i][j][k][l] = pp_ps_d;
352 } else if la > 0 && lb == 0 && lc > 0 && ld > 0 {
353 let mut ps_pp =
355 (q[l] - d[l]) * ps_ps[i][k][0] + (w[l] - q[l]) * ps_ps[i][k][1];
356 if i == l {
357 ps_pp += 1.0 / (2.0 * (zeta + eta))
358 * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
359 }
360 if k == l {
361 ps_pp += 1.0 / (2.0 * eta)
362 * (ps_ss[i][0] - zeta / (zeta + eta) * ps_ss[i][1]);
363 }
364 res[i][j][k][l] = ps_pp;
365 } else if la == 0 && lb > 0 && lc > 0 && ld > 0 {
366 let ps_ss_b0 = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
368 let ps_ss_b1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
369 let ps_ss_b2 = (p[j] - b[j]) * ss_ss(2) + (w[j] - p[j]) * ss_ss(3);
370 let ps_ps_bk0 = (q[k] - c[k]) * ps_ss_b0
371 + (w[k] - q[k]) * ps_ss_b1
372 + if j == k {
373 1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
374 } else {
375 0.0
376 };
377 let ps_ps_bk1 = (q[k] - c[k]) * ps_ss_b1
378 + (w[k] - q[k]) * ps_ss_b2
379 + if j == k {
380 1.0 / (2.0 * (zeta + eta)) * ss_ss(2)
381 } else {
382 0.0
383 };
384 let mut sp_pp = (q[l] - d[l]) * ps_ps_bk0 + (w[l] - q[l]) * ps_ps_bk1;
385 if j == l {
386 sp_pp += 1.0 / (2.0 * (zeta + eta))
387 * ((q[k] - c[k]) * ss_ss(1) + (w[k] - q[k]) * ss_ss(2));
388 }
389 if k == l {
390 sp_pp +=
391 1.0 / (2.0 * eta) * (ps_ss_b0 - zeta / (zeta + eta) * ps_ss_b1);
392 }
393 res[i][j][k][l] = sp_pp;
394 } else if la > 0 && lb > 0 && lc > 0 && ld > 0 {
395 let mut term =
397 (q[l] - d[l]) * pp_ps[i][j][k][0] + (w[l] - q[l]) * pp_ps[i][j][k][1];
398 if i == l {
399 term += 1.0 / (2.0 * (zeta + eta)) * ps_ps[j][k][1];
400 }
401 if j == l {
402 term += 1.0 / (2.0 * (zeta + eta)) * ps_ps[i][k][1];
403 }
404 if k == l {
405 term += 1.0 / (2.0 * eta)
406 * (pp_ss[i][j][0] - zeta / (zeta + eta) * pp_ss[i][j][1]);
407 }
408 res[i][j][k][l] = term;
409 } else if la > 0 && lb == 0 && lc == 0 && ld > 0 {
410 let val = (q[l] - d[l]) * ps_ss[i][0]
412 + (w[l] - q[l]) * ps_ss[i][1]
413 + if i == l {
414 1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
415 } else {
416 0.0
417 };
418 res[i][j][k][l] = val;
419 } else if la == 0 && lb > 0 && lc == 0 && ld > 0 {
420 let ps_ss_b = (p[j] - b[j]) * ss_ss(0) + (w[j] - p[j]) * ss_ss(1);
422 let ps_ss_b1 = (p[j] - b[j]) * ss_ss(1) + (w[j] - p[j]) * ss_ss(2);
423 res[i][j][k][l] = (q[l] - d[l]) * ps_ss_b
424 + (w[l] - q[l]) * ps_ss_b1
425 + if j == l {
426 1.0 / (2.0 * (zeta + eta)) * ss_ss(1)
427 } else {
428 0.0
429 };
430 }
431 }
432 }
433 }
434 }
435 res
436}
437
/// Centre of the product of two Gaussians with exponents `ea` at `a` and `eb` at `b`:
/// the exponent-weighted average `(ea * a + eb * b) / (ea + eb)`, per axis.
fn gaussian_product(a: &[f64; 3], ea: f64, b: &[f64; 3], eb: f64) -> [f64; 3] {
    let total_exponent = ea + eb;
    [0usize, 1, 2].map(|axis| (ea * a[axis] + eb * b[axis]) / total_exponent)
}
446
/// Squared Euclidean distance between two points in 3-D.
#[inline]
fn dist_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| x - y)
        .fold(0.0, |acc, delta| acc + delta * delta)
}
454
#[cfg(test)]
mod tests {
    use super::super::basis::build_sto3g_basis;
    use super::*;

    /// Packed ERIs and basis size for STO-3G H2 with nuclei at the origin and at
    /// 0.74 along z.
    fn h2_eris() -> (Vec<f64>, usize) {
        let basis = build_sto3g_basis(&[1, 1], &[[0.0, 0.0, 0.0], [0.0, 0.0, 0.74]]);
        let eris = compute_eris(&basis);
        (eris, basis.n_basis())
    }

    /// The first diagonal integral of H2 must be positive.
    #[test]
    fn test_eri_h2() {
        let (eris, n) = h2_eris();
        let eri_0000 = get_eri(&eris, 0, 0, 0, 0, n);
        assert!(eri_0000 > 0.0, "(11|11) = {eri_0000}");
    }

    /// Index permutations of (01|01) must return identical values.
    ///
    /// NOTE(review): every permutation checked here maps to the same packed slot
    /// via `eri_index`, so this exercises the indexing scheme rather than the
    /// integral evaluation itself.
    #[test]
    fn test_eri_symmetry() {
        let (eris, n) = h2_eris();
        let eri_0101 = get_eri(&eris, 0, 1, 0, 1, n);
        assert_eq!(eri_0101, get_eri(&eris, 1, 0, 1, 0, n));
        assert_eq!(eri_0101, get_eri(&eris, 0, 1, 1, 0, n));
    }
}