// sci_form/gpu/two_electron_gpu.rs
1//! GPU-accelerated two-electron repulsion integrals (ERIs).
2//!
3//! Offloads the O(N⁴) ERI computation to the GPU using a WGSL compute shader.
4//! Each workgroup thread computes one unique quartet (i,j,k,l) of contracted ERIs,
5//! exploiting 8-fold permutation symmetry to reduce work.
6//!
7//! For STO-3G basis sets (max 3 primitives/function), this provides significant
8//! speedup over CPU for molecules with > ~10 basis functions.
9
10use super::context::{
11    ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
12};
13use crate::scf::basis::BasisSet;
14use crate::scf::two_electron::TwoElectronIntegrals;
15
/// Maximum primitives per contracted basis function (STO-3G contracts 3).
/// Must stay in sync with the fixed `i * 3u + pi` stride used to index the
/// primitives buffer inside `TWO_ELECTRON_SHADER`.
const MAX_PRIMITIVES: usize = 3;

/// Minimum basis functions to justify GPU dispatch overhead.
/// Below this, buffer upload and pipeline setup dominate the integral cost.
const GPU_DISPATCH_THRESHOLD: usize = 4;
21
22/// Pack basis set data for the ERI GPU shader.
23///
24/// Basis buffer layout per function (32 bytes = 8 × f32):
25///   center: [f32; 3], lx: u32, ly: u32, lz: u32, n_primitives: u32, _pad: u32
26///
27/// Primitives buffer layout per function (24 bytes = 6 × f32):
28///   3 × (alpha: f32, norm_coeff: f32), padded with zeros if fewer primitives
29fn pack_basis_eri(basis: &BasisSet) -> (Vec<u8>, Vec<u8>) {
30    let mut basis_bytes = Vec::with_capacity(basis.n_basis * 32);
31    let mut prim_bytes = Vec::with_capacity(basis.n_basis * 24);
32
33    for bf in &basis.functions {
34        // center xyz (12 bytes)
35        basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
36        basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
37        basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
38        // lx, ly, lz (12 bytes)
39        basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
40        basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
41        basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
42        // n_primitives + padding (8 bytes)
43        basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
44        basis_bytes.extend_from_slice(&0u32.to_ne_bytes());
45
46        // Primitives: 3 × (alpha, norm*coeff)
47        for i in 0..MAX_PRIMITIVES {
48            if i < bf.primitives.len() {
49                let norm = crate::scf::basis::BasisFunction::normalization(
50                    bf.primitives[i].alpha,
51                    bf.angular[0],
52                    bf.angular[1],
53                    bf.angular[2],
54                );
55                prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
56                prim_bytes.extend_from_slice(
57                    &((bf.primitives[i].coefficient * norm) as f32).to_ne_bytes(),
58                );
59            } else {
60                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
61                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
62            }
63        }
64    }
65
66    (basis_bytes, prim_bytes)
67}
68
69/// Enumerate unique quartets with 8-fold symmetry: i≥j, k≥l, ij≥kl.
70/// Returns (n_quartets, quartet_indices) where quartet_indices is a flat
71/// array of [i, j, k, l] u32 tuples.
/// Enumerate unique quartets with 8-fold symmetry: i≥j, k≥l, ij≥kl.
/// Returns (n_quartets, quartet_indices) where quartet_indices is a flat
/// byte array of [i, j, k, l] u32 tuples (16 bytes per quartet).
fn enumerate_unique_quartets(n: usize) -> (usize, Vec<u8>) {
    // The count is known in closed form: with M = n(n+1)/2 unique (i,j)
    // pairs (each pair has a distinct ij = i*n + j index), there are
    // M(M+1)/2 ordered pair-of-pairs with ij ≥ kl. Preallocating avoids
    // O(n⁴)-element incremental reallocation while the Vec grows.
    let n_pairs = n * (n + 1) / 2;
    let n_quartets = n_pairs * (n_pairs + 1) / 2;
    let mut quartets = Vec::with_capacity(n_quartets * 16);

    for i in 0..n {
        for j in 0..=i {
            let ij = i * n + j;
            for k in 0..n {
                for l in 0..=k {
                    let kl = k * n + l;
                    if ij < kl {
                        continue;
                    }
                    quartets.extend_from_slice(&(i as u32).to_ne_bytes());
                    quartets.extend_from_slice(&(j as u32).to_ne_bytes());
                    quartets.extend_from_slice(&(k as u32).to_ne_bytes());
                    quartets.extend_from_slice(&(l as u32).to_ne_bytes());
                }
            }
        }
    }

    // The enumeration must agree exactly with the closed-form count.
    debug_assert_eq!(quartets.len(), n_quartets * 16);
    (n_quartets, quartets)
}
97
98/// Compute two-electron integrals on the GPU.
99///
100/// Returns `TwoElectronIntegrals` with the full symmetrized tensor,
101/// or an error if GPU dispatch fails. Falls back to CPU if the
102/// basis is too small to benefit from GPU.
103pub fn compute_eris_gpu(
104    ctx: &GpuContext,
105    basis: &BasisSet,
106) -> Result<TwoElectronIntegrals, String> {
107    let n = basis.n_basis;
108
109    if n < GPU_DISPATCH_THRESHOLD {
110        return Err("Basis too small for GPU dispatch".to_string());
111    }
112
113    // Check memory: output is n⁴ × 4 bytes (f32)
114    let output_size = (n * n * n * n * 4) as u64;
115    if output_size > ctx.capabilities.max_storage_buffer_size {
116        return Err(format!(
117            "ERI tensor ({} bytes) exceeds GPU storage limit ({} bytes)",
118            output_size, ctx.capabilities.max_storage_buffer_size
119        ));
120    }
121
122    let (basis_bytes, prim_bytes) = pack_basis_eri(basis);
123    let (n_quartets, quartet_bytes) = enumerate_unique_quartets(n);
124
125    // Params uniform: [n_basis: u32, n_quartets: u32, pad: u32, pad: u32] = 16 bytes
126    let mut params = Vec::with_capacity(16);
127    params.extend_from_slice(&(n as u32).to_ne_bytes());
128    params.extend_from_slice(&(n_quartets as u32).to_ne_bytes());
129    params.extend_from_slice(&0u32.to_ne_bytes());
130    params.extend_from_slice(&0u32.to_ne_bytes());
131
132    // Output buffer (f32, n⁴ elements)
133    let output_seed = vec![0u8; n * n * n * n * 4];
134
135    let workgroup_count = [(n_quartets as u32).div_ceil(64), 1, 1];
136
137    let descriptor = ComputeDispatchDescriptor {
138        label: "two-electron ERI".to_string(),
139        shader_source: TWO_ELECTRON_SHADER.to_string(),
140        entry_point: "main".to_string(),
141        workgroup_count,
142        bindings: vec![
143            ComputeBindingDescriptor {
144                label: "basis".to_string(),
145                kind: ComputeBindingKind::StorageReadOnly,
146                bytes: basis_bytes,
147            },
148            ComputeBindingDescriptor {
149                label: "primitives".to_string(),
150                kind: ComputeBindingKind::StorageReadOnly,
151                bytes: prim_bytes,
152            },
153            ComputeBindingDescriptor {
154                label: "quartets".to_string(),
155                kind: ComputeBindingKind::StorageReadOnly,
156                bytes: quartet_bytes,
157            },
158            ComputeBindingDescriptor {
159                label: "params".to_string(),
160                kind: ComputeBindingKind::Uniform,
161                bytes: params,
162            },
163            ComputeBindingDescriptor {
164                label: "output".to_string(),
165                kind: ComputeBindingKind::StorageReadWrite,
166                bytes: output_seed,
167            },
168        ],
169    };
170
171    let mut result = ctx.run_compute(&descriptor)?;
172    let bytes = result.outputs.pop().ok_or("No output from ERI kernel")?;
173
174    // Convert f32 output to f64
175    if bytes.len() != n * n * n * n * 4 {
176        return Err(format!(
177            "ERI output size mismatch: expected {}, got {}",
178            n * n * n * n * 4,
179            bytes.len()
180        ));
181    }
182
183    let data: Vec<f64> = bytes
184        .chunks_exact(4)
185        .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]) as f64)
186        .collect();
187
188    Ok(TwoElectronIntegrals::from_raw(data, n))
189}
190
/// WGSL compute shader for two-electron repulsion integrals.
///
/// Each thread processes one unique quartet (i,j,k,l) and writes
/// all 8 symmetry-related elements to the output tensor.
///
/// Computes contracted ERIs: (μν|λσ) = Σ_{pqrs} Nₚcₚ·Nqcq·Nrcr·Nscs · [pq|rs]
/// where [pq|rs] = 2π^{5/2}/(p·q·√(p+q)) · Kab·Kcd · F₀(αpq·|PQ|²)
///
/// F₀ is evaluated in f32 by a truncated series (x ≤ 30) or the asymptotic
/// tail √π/(2√x) beyond that.
///
/// NOTE(review): the kernel never reads the packed lx/ly/lz fields, so the
/// closed form above is the s-type (zero angular momentum) ERI only —
/// results for p- or higher-shell functions would be incorrect. The
/// primitives buffer is indexed with a fixed stride of 3 (MAX_PRIMITIVES).
pub const TWO_ELECTRON_SHADER: &str = r#"
struct BasisFunc {
    cx: f32, cy: f32, cz: f32,
    lx: u32, ly: u32, lz: u32,
    n_prims: u32, _pad: u32,
};

struct Params {
    n_basis: u32,
    n_quartets: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> primitives: array<vec2<f32>>;  // (alpha, norm*coeff)
@group(0) @binding(2) var<storage, read> quartets: array<vec4<u32>>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;

// Boys function F₀(x) via series expansion
fn boys_f0(x: f32) -> f32 {
    if (x < 1e-7) {
        return 1.0;
    }
    if (x > 30.0) {
        return 0.8862269 / sqrt(x);  // √π / (2√x)
    }
    var sum: f32 = 1.0;
    var term: f32 = 1.0;
    for (var k: u32 = 1u; k < 50u; k = k + 1u) {
        term *= 2.0 * x / f32(2u * k + 1u);
        sum += term;
        if (abs(term) < 1e-6 * abs(sum)) {
            break;
        }
    }
    return exp(-x) * sum;
}

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    if (idx >= params.n_quartets) {
        return;
    }

    let q = quartets[idx];
    let i = q.x;
    let j = q.y;
    let k = q.z;
    let l = q.w;

    let bf_i = basis[i];
    let bf_j = basis[j];
    let bf_k = basis[k];
    let bf_l = basis[l];

    var eri: f32 = 0.0;

    // Contract over primitives (max 3 per function for STO-3G)
    for (var pi: u32 = 0u; pi < bf_i.n_prims; pi = pi + 1u) {
        let prim_a = primitives[i * 3u + pi];
        let alpha = prim_a.x;
        let ca = prim_a.y;

        for (var pj: u32 = 0u; pj < bf_j.n_prims; pj = pj + 1u) {
            let prim_b = primitives[j * 3u + pj];
            let beta = prim_b.x;
            let cb = prim_b.y;

            let p = alpha + beta;
            let mu_ab = alpha * beta / p;
            let ab2 = (bf_i.cx - bf_j.cx) * (bf_i.cx - bf_j.cx)
                     + (bf_i.cy - bf_j.cy) * (bf_i.cy - bf_j.cy)
                     + (bf_i.cz - bf_j.cz) * (bf_i.cz - bf_j.cz);
            let k_ab = exp(-mu_ab * ab2);

            let px = (alpha * bf_i.cx + beta * bf_j.cx) / p;
            let py = (alpha * bf_i.cy + beta * bf_j.cy) / p;
            let pz = (alpha * bf_i.cz + beta * bf_j.cz) / p;

            for (var pk: u32 = 0u; pk < bf_k.n_prims; pk = pk + 1u) {
                let prim_c = primitives[k * 3u + pk];
                let gamma = prim_c.x;
                let cc = prim_c.y;

                for (var pl: u32 = 0u; pl < bf_l.n_prims; pl = pl + 1u) {
                    let prim_d = primitives[l * 3u + pl];
                    let delta = prim_d.x;
                    let cd = prim_d.y;

                    let qq = gamma + delta;
                    let mu_cd = gamma * delta / qq;
                    let cd2 = (bf_k.cx - bf_l.cx) * (bf_k.cx - bf_l.cx)
                             + (bf_k.cy - bf_l.cy) * (bf_k.cy - bf_l.cy)
                             + (bf_k.cz - bf_l.cz) * (bf_k.cz - bf_l.cz);
                    let k_cd = exp(-mu_cd * cd2);

                    let qx = (gamma * bf_k.cx + delta * bf_l.cx) / qq;
                    let qy = (gamma * bf_k.cy + delta * bf_l.cy) / qq;
                    let qz = (gamma * bf_k.cz + delta * bf_l.cz) / qq;

                    let pq2 = (px - qx) * (px - qx)
                            + (py - qy) * (py - qy)
                            + (pz - qz) * (pz - qz);
                    let alpha_pq = p * qq / (p + qq);

                    // prefactor = 2π^{5/2} / (p · q · √(p+q))
                    let prefactor = 2.0 * pow(3.14159265, 2.5) / (p * qq * sqrt(p + qq));

                    eri += ca * cb * cc * cd * prefactor * k_ab * k_cd * boys_f0(alpha_pq * pq2);
                }
            }
        }
    }

    let n = params.n_basis;
    let n2 = n * n;

    // Store with 8-fold symmetry
    output[i * n * n2 + j * n2 + k * n + l] = eri;
    output[j * n * n2 + i * n2 + k * n + l] = eri;
    output[i * n * n2 + j * n2 + l * n + k] = eri;
    output[j * n * n2 + i * n2 + l * n + k] = eri;
    output[k * n * n2 + l * n2 + i * n + j] = eri;
    output[l * n * n2 + k * n2 + i * n + j] = eri;
    output[k * n * n2 + l * n2 + j * n + i] = eri;
    output[l * n * n2 + k * n2 + j * n + i] = eri;
}
"#;
329
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_enumerate_quartets_h2() {
        // H₂ STO-3G: 2 basis functions → M = 3 unique (i,j) pairs with i≥j,
        // hence exactly M(M+1)/2 = 6 unique quartets under 8-fold symmetry.
        let (count, bytes) = enumerate_unique_quartets(2);
        assert_eq!(count, 6);
        assert_eq!(bytes.len(), count * 16); // 4 × u32 per quartet
    }

    #[test]
    fn test_enumerate_quartets_single() {
        // 1 basis function → exactly 1 quartet: (0,0,0,0) → 16 zero bytes.
        let (count, bytes) = enumerate_unique_quartets(1);
        assert_eq!(count, 1);
        assert_eq!(bytes, vec![0u8; 16]);
    }

    #[test]
    fn test_enumerate_quartets_empty() {
        // Degenerate case: no basis functions → no quartets, no bytes.
        let (count, bytes) = enumerate_unique_quartets(0);
        assert_eq!(count, 0);
        assert!(bytes.is_empty());
    }

    #[test]
    fn test_pack_basis_eri() {
        let basis =
            crate::scf::basis::BasisSet::sto3g(&[1, 1], &[[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]);
        let (basis_bytes, prim_bytes) = pack_basis_eri(&basis);
        // Basis: 32 bytes per function
        assert_eq!(basis_bytes.len(), basis.n_basis * 32);
        // Primitives: 24 bytes per function (3 × 8)
        assert_eq!(prim_bytes.len(), basis.n_basis * 24);
    }

    #[test]
    fn test_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let basis = crate::scf::basis::BasisSet::sto3g(&[1], &[[0.0, 0.0, 0.0]]);
        let result = compute_eris_gpu(&ctx, &basis);
        assert!(result.is_err()); // Too small for GPU
    }
}