Skip to main content

sci_form/gpu/
fock_build_gpu.rs

1//! GPU-accelerated Fock matrix construction.
2//!
3//! Computes the two-electron part G(P) of the Fock matrix on the GPU:
4//!   G(μν) = Σ_{λσ} P_{λσ} [(μν|λσ) - 0.5·(μλ|νσ)]
5//!   F = H_core + G(P)
6//!
//! Each GPU thread handles one (μ,ν) matrix element: it reads the full density matrix
//! and the 2·N² ERI elements (μν|λσ) and (μλ|νσ) needed for the Coulomb-exchange sum.
9
10use super::context::{
11    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
12    ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13    UniformValue,
14};
15
16/// Minimum matrix dimension to justify GPU dispatch.
17const GPU_DISPATCH_THRESHOLD: usize = 4;
18
19/// Build the Fock matrix on the GPU.
20///
21/// Inputs:
22/// - `h_core`: Core Hamiltonian (N×N), row-major flattened f64
23/// - `density`: Density matrix P (N×N), row-major flattened f64
24/// - `eris`: Two-electron integrals (N⁴), flattened f64
25/// - `n_basis`: Number of basis functions
26///
27/// Returns the Fock matrix F = H + G(P) as flattened f64 (N×N).
28pub fn build_fock_gpu(
29    ctx: &GpuContext,
30    h_core: &[f64],
31    density: &[f64],
32    eris: &[f64],
33    n_basis: usize,
34) -> Result<Vec<f64>, String> {
35    if n_basis < GPU_DISPATCH_THRESHOLD {
36        return Err("Basis too small for GPU dispatch".to_string());
37    }
38
39    let n2 = n_basis * n_basis;
40    if h_core.len() != n2 || density.len() != n2 {
41        return Err("Matrix dimension mismatch".to_string());
42    }
43
44    // Pack matrices as f32
45    let h_core_f32: Vec<f32> = h_core.iter().map(|v| *v as f32).collect();
46    let density_f32: Vec<f32> = density.iter().map(|v| *v as f32).collect();
47    let eris_f32: Vec<f32> = eris.iter().map(|v| *v as f32).collect();
48
49    // Params: [n_basis: u32, pad: u32, pad: u32, pad: u32] = 16 bytes
50    let params = pack_uniform_values(&[
51        UniformValue::U32(n_basis as u32),
52        UniformValue::U32(0),
53        UniformValue::U32(0),
54        UniformValue::U32(0),
55    ]);
56
57    // Output: Fock matrix (N×N f32)
58    let output_seed = vec![0.0f32; n2];
59
60    let wg_size = 16u32;
61    let wg_x = ceil_div_u32(n_basis, wg_size);
62    let wg_y = wg_x;
63
64    let descriptor = ComputeDispatchDescriptor {
65        label: "fock matrix build".to_string(),
66        shader_source: FOCK_BUILD_SHADER.to_string(),
67        entry_point: "main".to_string(),
68        workgroup_count: [wg_x, wg_y, 1],
69        bindings: vec![
70            ComputeBindingDescriptor {
71                label: "h_core".to_string(),
72                kind: ComputeBindingKind::StorageReadOnly,
73                bytes: f32_slice_to_bytes(&h_core_f32),
74            },
75            ComputeBindingDescriptor {
76                label: "density".to_string(),
77                kind: ComputeBindingKind::StorageReadOnly,
78                bytes: f32_slice_to_bytes(&density_f32),
79            },
80            ComputeBindingDescriptor {
81                label: "eris".to_string(),
82                kind: ComputeBindingKind::StorageReadOnly,
83                bytes: f32_slice_to_bytes(&eris_f32),
84            },
85            ComputeBindingDescriptor {
86                label: "params".to_string(),
87                kind: ComputeBindingKind::Uniform,
88                bytes: params,
89            },
90            ComputeBindingDescriptor {
91                label: "output".to_string(),
92                kind: ComputeBindingKind::StorageReadWrite,
93                bytes: f32_slice_to_bytes(&output_seed),
94            },
95        ],
96    };
97
98    let mut result = ctx.run_compute(&descriptor)?;
99    let bytes = result
100        .outputs
101        .pop()
102        .ok_or("No output from Fock build kernel")?;
103
104    if bytes.len() != n2 * 4 {
105        return Err(format!(
106            "Fock output size mismatch: expected {}, got {}",
107            n2 * 4,
108            bytes.len()
109        ));
110    }
111
112    let fock = bytes_to_f64_vec_from_f32(&bytes);
113
114    Ok(fock)
115}
116
/// WGSL compute shader for Fock matrix construction.
///
/// F(μ,ν) = H_core(μ,ν) + Σ_{λσ} P(λ,σ) [(μν|λσ) - 0.5·(μλ|νσ)]
///
/// Bindings (all in `@group(0)`):
/// - 0: `h_core`: read-only storage, N×N `f32`, indexed row-major as `μ·N + ν`.
/// - 1: `density`: read-only storage, N×N `f32`, indexed row-major as `λ·N + σ`.
/// - 2: `eris`: read-only storage, N⁴ `f32`; (μν|λσ) is read at `μ·N³ + ν·N² + λ·N + σ`.
/// - 3: `params`: uniform `{ n_basis: u32, _pad0, _pad1, _pad2 }` (16 bytes).
/// - 4: `fock`: read-write storage, N×N `f32` output, written at `μ·N + ν`.
///
/// Workgroup size: (16, 16, 1) = 256 threads.
/// Each thread computes one F(μ,ν) element, with μ = `global_invocation_id.x` and
/// ν = `global_invocation_id.y`; invocations with μ ≥ N or ν ≥ N return immediately, so
/// the dispatch may be rounded up to whole workgroups. Each thread reads the full density
/// matrix and 2·N² ERI elements, accumulating in `f32`.
///
/// All buffer offsets are computed in `u32` arithmetic, so N⁴ must stay within the
/// 32-bit range for the indexing to be correct.
pub const FOCK_BUILD_SHADER: &str = r#"
struct Params {
    n_basis: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> h_core: array<f32>;
@group(0) @binding(1) var<storage, read> density: array<f32>;
@group(0) @binding(2) var<storage, read> eris: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> fock: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let mu = gid.x;
    let nu = gid.y;
    let n = params.n_basis;

    if (mu >= n || nu >= n) {
        return;
    }

    let n2 = n * n;
    var g_mn: f32 = 0.0;

    // G(μ,ν) = Σ_{λσ} P(λ,σ) · [(μν|λσ) - 0.5·(μλ|νσ)]
    for (var lam: u32 = 0u; lam < n; lam = lam + 1u) {
        for (var sig: u32 = 0u; sig < n; sig = sig + 1u) {
            let p_ls = density[lam * n + sig];

            // Coulomb: (μν|λσ)
            let j_idx = mu * n * n2 + nu * n2 + lam * n + sig;
            let j_val = eris[j_idx];

            // Exchange: (μλ|νσ)
            let k_idx = mu * n * n2 + lam * n2 + nu * n + sig;
            let k_val = eris[k_idx];

            g_mn += p_ls * (j_val - 0.5 * k_val);
        }
    }

    fock[mu * n + nu] = h_core[mu * n + nu] + g_mn;
}
"#;
170
171#[cfg(test)]
172mod tests {
173    use super::*;
174
175    #[test]
176    fn test_f32_slice_to_bytes() {
177        let data = vec![1.0f32, 2.0, 3.0];
178        let bytes = f32_slice_to_bytes(&data);
179        assert_eq!(bytes.len(), 12);
180    }
181
182    #[test]
183    fn test_gpu_threshold() {
184        let ctx = GpuContext::cpu_fallback();
185        let n = 2;
186        let h = vec![0.0f64; n * n];
187        let d = vec![0.0f64; n * n];
188        let e = vec![0.0f64; n * n * n * n];
189        let result = build_fock_gpu(&ctx, &h, &d, &e, n);
190        assert!(result.is_err());
191    }
192
193    #[test]
194    fn test_dimension_mismatch() {
195        let ctx = GpuContext::cpu_fallback();
196        // n=5 passes threshold but matrices have wrong sizes
197        let result = build_fock_gpu(&ctx, &[0.0; 25], &[0.0; 16], &[0.0; 625], 5);
198        assert!(result.is_err());
199        assert!(result.unwrap_err().contains("mismatch"));
200    }
201}