Skip to main content

sci_form/gpu/
one_electron_gpu.rs

//! GPU-accelerated one-electron matrix construction.
//!
//! Computes the overlap (S), kinetic (T), and nuclear attraction (V) matrices
//! on the GPU in a single dispatch. Each thread handles one (μ,ν) pair
//! and writes one element of each of the three matrices.
//!
//! For STO-3G basis sets each basis function has at most 3 primitives, so
//! every one of the N² matrix elements contracts up to 3×3 = 9 primitive pairs.
9
10use super::context::{
11    ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
12};
13use crate::scf::basis::BasisSet;
14use crate::scf::kinetic_matrix::build_kinetic_matrix;
15use crate::scf::nuclear_matrix::build_nuclear_matrix;
16use crate::scf::overlap_matrix::build_overlap_matrix;
17
/// Maximum primitives per basis function (STO-3G = 3).
///
/// This is also the fixed number of primitive slots packed per basis function
/// by `pack_one_electron_data`; the WGSL kernel addresses primitives as
/// `function_index * 3u + k`, so the value must stay in sync with that stride.
const MAX_PRIMITIVES: usize = 3;
20
/// Minimum basis functions to justify GPU dispatch.
///
/// `compute_one_electron_gpu` returns an `Err` when the basis has fewer
/// functions than this.
const GPU_DISPATCH_THRESHOLD: usize = 6;
23
/// Result of a one-electron matrix computation (GPU kernel or CPU fallback).
///
/// All three matrices are stored flat in row-major order, so element (i, j)
/// lives at index `i * n_basis + j`.
#[derive(Debug, Clone, PartialEq)]
pub struct OneElectronResult {
    /// Overlap matrix S, row-major flat (N×N).
    pub overlap: Vec<f64>,
    /// Kinetic energy matrix T, row-major flat (N×N).
    pub kinetic: Vec<f64>,
    /// Nuclear attraction matrix V, row-major flat (N×N).
    pub nuclear: Vec<f64>,
    /// Number of basis functions.
    pub n_basis: usize,
    /// Whether the matrices were actually produced by a GPU dispatch.
    pub used_gpu: bool,
    /// Backend that produced the result.
    pub backend: String,
    /// Human-readable execution note.
    pub note: String,
}
41
42fn flatten_matrix_row_major(matrix: &nalgebra::DMatrix<f64>) -> Vec<f64> {
43    (0..matrix.nrows())
44        .flat_map(|i| (0..matrix.ncols()).map(move |j| matrix[(i, j)]))
45        .collect()
46}
47
48fn compute_one_electron_cpu_exact(
49    basis: &BasisSet,
50    elements: &[u8],
51    positions_bohr: &[[f64; 3]],
52    note: impl Into<String>,
53) -> OneElectronResult {
54    let overlap = build_overlap_matrix(basis);
55    let kinetic = build_kinetic_matrix(basis);
56    let nuclear = build_nuclear_matrix(basis, elements, positions_bohr);
57
58    OneElectronResult {
59        overlap: flatten_matrix_row_major(&overlap),
60        kinetic: flatten_matrix_row_major(&kinetic),
61        nuclear: flatten_matrix_row_major(&nuclear),
62        n_basis: basis.n_basis,
63        used_gpu: false,
64        backend: "CPU-exact".to_string(),
65        note: note.into(),
66    }
67}
68
69fn basis_supports_exact_gpu_kernel(basis: &BasisSet) -> bool {
70    basis.functions.iter().all(|bf| bf.l_total == 0)
71}
72
73/// Pack basis set and nuclear data for the one-electron GPU shader.
74///
75/// Basis buffer: per function (32 bytes):
76///   center [f32;3], lx: u32, ly: u32, lz: u32, n_prims: u32, _pad: u32
77///
78/// Primitives buffer: per function (24 bytes):
79///   3 × (alpha: f32, norm*coeff: f32)
80///
81/// Atoms buffer: per atom (16 bytes):
82///   position [f32;3], atomic_number: f32
83fn pack_one_electron_data(
84    basis: &BasisSet,
85    elements: &[u8],
86    positions_bohr: &[[f64; 3]],
87) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
88    let mut basis_bytes = Vec::with_capacity(basis.n_basis * 32);
89    let mut prim_bytes = Vec::with_capacity(basis.n_basis * MAX_PRIMITIVES * 8);
90
91    for bf in &basis.functions {
92        basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
93        basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
94        basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
95        basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
96        basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
97        basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
98        basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
99        basis_bytes.extend_from_slice(&0u32.to_ne_bytes());
100
101        for i in 0..MAX_PRIMITIVES {
102            if i < bf.primitives.len() {
103                let norm = crate::scf::basis::BasisFunction::normalization(
104                    bf.primitives[i].alpha,
105                    bf.angular[0],
106                    bf.angular[1],
107                    bf.angular[2],
108                );
109                prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
110                prim_bytes.extend_from_slice(
111                    &((bf.primitives[i].coefficient * norm) as f32).to_ne_bytes(),
112                );
113            } else {
114                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
115                prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
116            }
117        }
118    }
119
120    // Pack atom positions and charges
121    let mut atom_bytes = Vec::with_capacity(elements.len() * 16);
122    for (i, &z) in elements.iter().enumerate() {
123        atom_bytes.extend_from_slice(&(positions_bohr[i][0] as f32).to_ne_bytes());
124        atom_bytes.extend_from_slice(&(positions_bohr[i][1] as f32).to_ne_bytes());
125        atom_bytes.extend_from_slice(&(positions_bohr[i][2] as f32).to_ne_bytes());
126        atom_bytes.extend_from_slice(&(z as f32).to_ne_bytes());
127    }
128
129    (basis_bytes, prim_bytes, atom_bytes)
130}
131
132/// Compute one-electron matrices (S, T, V) on the GPU.
133///
134/// Returns overlap, kinetic, and nuclear matrices packed as flat f64 arrays.
135pub fn compute_one_electron_gpu(
136    ctx: &GpuContext,
137    basis: &BasisSet,
138    elements: &[u8],
139    positions_bohr: &[[f64; 3]],
140) -> Result<OneElectronResult, String> {
141    let n = basis.n_basis;
142    let n_atoms = elements.len();
143
144    if n < GPU_DISPATCH_THRESHOLD {
145        return Err("Basis too small for GPU dispatch".to_string());
146    }
147
148    if !basis_supports_exact_gpu_kernel(basis) {
149        return Ok(compute_one_electron_cpu_exact(
150            basis,
151            elements,
152            positions_bohr,
153            "Fell back to exact CPU one-electron builders because the current WGSL kernel only supports pure s-type basis functions.",
154        ));
155    }
156
157    let (basis_bytes, prim_bytes, atom_bytes) =
158        pack_one_electron_data(basis, elements, positions_bohr);
159
160    // Params: n_basis, n_atoms, pad, pad (16 bytes)
161    let mut params = Vec::with_capacity(16);
162    params.extend_from_slice(&(n as u32).to_ne_bytes());
163    params.extend_from_slice(&(n_atoms as u32).to_ne_bytes());
164    params.extend_from_slice(&0u32.to_ne_bytes());
165    params.extend_from_slice(&0u32.to_ne_bytes());
166
167    // Output: 3 matrices (S, T, V) packed sequentially, each N×N f32
168    let output_size = 3 * n * n;
169    let output_seed = vec![0.0f32; output_size];
170    let output_bytes: Vec<u8> = output_seed.iter().flat_map(|v| v.to_ne_bytes()).collect();
171
172    let wg_size = 16u32;
173    let wg_x = (n as u32).div_ceil(wg_size);
174    let wg_y = wg_x;
175
176    let descriptor = ComputeDispatchDescriptor {
177        label: "one-electron matrices".to_string(),
178        shader_source: ONE_ELECTRON_SHADER.to_string(),
179        entry_point: "main".to_string(),
180        workgroup_count: [wg_x, wg_y, 1],
181        bindings: vec![
182            ComputeBindingDescriptor {
183                label: "basis".to_string(),
184                kind: ComputeBindingKind::StorageReadOnly,
185                bytes: basis_bytes,
186            },
187            ComputeBindingDescriptor {
188                label: "primitives".to_string(),
189                kind: ComputeBindingKind::StorageReadOnly,
190                bytes: prim_bytes,
191            },
192            ComputeBindingDescriptor {
193                label: "atoms".to_string(),
194                kind: ComputeBindingKind::StorageReadOnly,
195                bytes: atom_bytes,
196            },
197            ComputeBindingDescriptor {
198                label: "params".to_string(),
199                kind: ComputeBindingKind::Uniform,
200                bytes: params,
201            },
202            ComputeBindingDescriptor {
203                label: "output".to_string(),
204                kind: ComputeBindingKind::StorageReadWrite,
205                bytes: output_bytes,
206            },
207        ],
208    };
209
210    let mut result = ctx.run_compute(&descriptor)?;
211    let bytes = result
212        .outputs
213        .pop()
214        .ok_or("No output from one-electron kernel")?;
215
216    if bytes.len() != output_size * 4 {
217        return Err(format!(
218            "Output size mismatch: expected {}, got {}",
219            output_size * 4,
220            bytes.len()
221        ));
222    }
223
224    let all_f64: Vec<f64> = bytes
225        .chunks_exact(4)
226        .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]) as f64)
227        .collect();
228
229    let n2 = n * n;
230    Ok(OneElectronResult {
231        overlap: all_f64[..n2].to_vec(),
232        kinetic: all_f64[n2..2 * n2].to_vec(),
233        nuclear: all_f64[2 * n2..].to_vec(),
234        n_basis: n,
235        used_gpu: true,
236        backend: ctx.capabilities.backend.clone(),
237        note: "Executed WGSL one-electron kernel for a pure s-type basis.".to_string(),
238    })
239}
240
/// WGSL compute shader for one-electron matrices (S, T, V).
///
/// Each thread computes S(μ,ν), T(μ,ν), and V(μ,ν) for one basis pair.
/// Every thread evaluates its pair in the canonical order
/// (min(μ,ν), max(μ,ν)), so the (μ,ν) and (ν,μ) elements come from identical
/// arithmetic and the matrices are exactly symmetric. No work is skipped:
/// all N² elements are computed.
///
/// For a pair of s-type primitives with exponents α, β on centres A, B, with
/// p = α + β, μ = αβ/p, K = exp(-μ·|AB|²) and P = (αA + βB)/p:
///   S = (π/p)^{3/2} · K
///   T = μ · (3 - 2μ·|AB|²) · S
///   V = -(2π/p) · K · Σ_C Z_C · F₀(p·|PC|²)
///
/// Primitives are read with a fixed stride of 3 slots per basis function.
pub const ONE_ELECTRON_SHADER: &str = r#"
struct BasisFunc {
    cx: f32, cy: f32, cz: f32,
    lx: u32, ly: u32, lz: u32,
    n_prims: u32, _pad: u32,
};

struct Atom {
    x: f32, y: f32, z: f32,
    charge: f32,
};

struct Params {
    n_basis: u32,
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(2) var<storage, read> atoms: array<Atom>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
// Output layout: [S_00..S_{nn}] [T_00..T_{nn}] [V_00..V_{nn}]

// Zeroth-order Boys function F0(x) = integral_0^1 exp(-x t^2) dt.
fn boys_f0(x: f32) -> f32 {
    if (x < 1e-7) {
        return 1.0;
    }
    // Asymptotic form sqrt(pi) / (2 sqrt(x)). Its relative error is
    // erfc(sqrt(x)), which is below 2e-8 for x > 16, i.e. under f32
    // resolution. Switching here also keeps the series below within its
    // 50-term budget; for larger x the series would be cut off before
    // it has converged.
    if (x > 16.0) {
        return 0.8862269 / sqrt(x);
    }
    // Series: F0(x) = exp(-x) * sum_k (2x)^k / (2k+1)!!
    var sum: f32 = 1.0;
    var term: f32 = 1.0;
    for (var k: u32 = 1u; k < 50u; k = k + 1u) {
        term *= 2.0 * x / f32(2u * k + 1u);
        sum += term;
        if (abs(term) < 1e-6 * abs(sum)) {
            break;
        }
    }
    return exp(-x) * sum;
}

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let mu = gid.x;
    let nu = gid.y;
    let n = params.n_basis;
    let n2 = n * n;

    if (mu >= n || nu >= n) {
        return;
    }

    // Evaluate the pair in canonical (min, max) order so that (mu, nu) and
    // (nu, mu) produce bit-identical values.
    let compute_mu = min(mu, nu);
    let compute_nu = max(mu, nu);

    let bf_a = basis[compute_mu];
    let bf_b = basis[compute_nu];

    let pi = 3.14159265;

    // |AB|^2 depends only on the two centres, not on the primitives.
    let dx = bf_a.cx - bf_b.cx;
    let dy = bf_a.cy - bf_b.cy;
    let dz = bf_a.cz - bf_b.cz;
    let ab2 = dx * dx + dy * dy + dz * dz;

    var s_total: f32 = 0.0;
    var t_total: f32 = 0.0;
    var v_total: f32 = 0.0;

    // Contract over primitives
    for (var pa: u32 = 0u; pa < bf_a.n_prims; pa = pa + 1u) {
        let prim_a = primitives[compute_mu * 3u + pa];
        let alpha = prim_a.x;
        let ca = prim_a.y;

        for (var pb: u32 = 0u; pb < bf_b.n_prims; pb = pb + 1u) {
            let prim_b = primitives[compute_nu * 3u + pb];
            let beta = prim_b.x;
            let cb = prim_b.y;

            let p = alpha + beta;
            let mu_ab = alpha * beta / p;
            let k_ab = exp(-mu_ab * ab2);

            // Overlap: S = (π/p)^{3/2} · K_ab
            let s_prim = pow(pi / p, 1.5) * k_ab;
            s_total += ca * cb * s_prim;

            // Kinetic: T = μ(3 - 2μ·|AB|²) · S  [s-type simplification]
            let t_prim = mu_ab * (3.0 - 2.0 * mu_ab * ab2) * s_prim;
            t_total += ca * cb * t_prim;

            // Nuclear attraction: V = -Σ_C Z_C · 2π/p · K_ab · F₀(p·|PC|²)
            // P is the Gaussian product centre.
            let px = (alpha * bf_a.cx + beta * bf_b.cx) / p;
            let py = (alpha * bf_a.cy + beta * bf_b.cy) / p;
            let pz = (alpha * bf_a.cz + beta * bf_b.cz) / p;

            for (var c: u32 = 0u; c < params.n_atoms; c = c + 1u) {
                let atom = atoms[c];
                let pc2 = (px - atom.x) * (px - atom.x)
                        + (py - atom.y) * (py - atom.y)
                        + (pz - atom.z) * (pz - atom.z);
                let v_prim = -atom.charge * 2.0 * pi / p * k_ab * boys_f0(p * pc2);
                v_total += ca * cb * v_prim;
            }
        }
    }

    // Write to output: S at offset 0, T at offset n², V at offset 2n²
    output[mu * n + nu] = s_total;
    output[n2 + mu * n + nu] = t_total;
    output[2u * n2 + mu * n + nu] = v_total;
}
"#;
366
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two flat matrices have the same length and agree
    /// element-wise to within `tol`.
    fn assert_flat_close(actual: &[f64], expected: &[f64], tol: f64) {
        assert_eq!(actual.len(), expected.len());
        for (idx, (lhs, rhs)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (lhs - rhs).abs() < tol,
                "element {}: {} vs {}",
                idx,
                lhs,
                rhs
            );
        }
    }

    /// The packed buffers must have the fixed per-record sizes the shader expects.
    #[test]
    fn test_pack_one_electron() {
        let basis =
            crate::scf::basis::BasisSet::sto3g(&[1, 1], &[[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]);
        let (basis_bytes, prim_bytes, atom_bytes) =
            pack_one_electron_data(&basis, &[1, 1], &[[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]]);
        assert_eq!(basis_bytes.len(), basis.n_basis * 32);
        assert_eq!(prim_bytes.len(), basis.n_basis * MAX_PRIMITIVES * 8);
        assert_eq!(atom_bytes.len(), 2 * 16);

        // Second atom record: x = 1.4, charge = 1 (both stored as f32).
        let x1 = f32::from_ne_bytes([
            atom_bytes[16],
            atom_bytes[17],
            atom_bytes[18],
            atom_bytes[19],
        ]);
        let z1 = f32::from_ne_bytes([
            atom_bytes[28],
            atom_bytes[29],
            atom_bytes[30],
            atom_bytes[31],
        ]);
        assert_eq!(x1, 1.4f32);
        assert_eq!(z1, 1.0f32);
    }

    /// A basis below the dispatch threshold is rejected with the size error.
    #[test]
    fn test_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let basis = crate::scf::basis::BasisSet::sto3g(&[1], &[[0.0, 0.0, 0.0]]);
        let result = compute_one_electron_gpu(&ctx, &basis, &[1], &[[0.0, 0.0, 0.0]]);
        match result {
            Err(message) => assert!(message.contains("too small")),
            Ok(_) => panic!("a single-function basis must not be dispatched"),
        }
    }

    /// A basis with p-type functions must take the CPU path and reproduce all
    /// three CPU-built matrices.
    #[test]
    fn test_mixed_angular_momentum_falls_back_to_exact_cpu() {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let basis = crate::scf::basis::BasisSet::sto3g(&elements, &positions);
        let ctx = GpuContext::cpu_fallback();

        let result = compute_one_electron_gpu(&ctx, &basis, &elements, &positions)
            .expect("mixed-angular basis should fall back to CPU");

        assert!(!result.used_gpu);
        assert_eq!(result.backend, "CPU-exact");
        assert!(result.note.contains("pure s-type"));
        assert_eq!(result.n_basis, basis.n_basis);

        let overlap_flat = flatten_matrix_row_major(&build_overlap_matrix(&basis));
        let kinetic_flat = flatten_matrix_row_major(&build_kinetic_matrix(&basis));
        let nuclear_flat =
            flatten_matrix_row_major(&build_nuclear_matrix(&basis, &elements, &positions));

        assert_flat_close(&result.overlap, &overlap_flat, 1e-12);
        assert_flat_close(&result.kinetic, &kinetic_flat, 1e-12);
        assert_flat_close(&result.nuclear, &nuclear_flat, 1e-12);
    }
}