sci_form/gpu/
fock_build_gpu.rs1use super::context::{
11 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
12 ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13 UniformValue,
14};
15
/// Smallest `n_basis` accepted by `build_fock_gpu`; calls with fewer basis
/// functions are rejected with an error before any GPU work is set up.
const GPU_DISPATCH_THRESHOLD: usize = 4;
18
19pub fn build_fock_gpu(
29 ctx: &GpuContext,
30 h_core: &[f64],
31 density: &[f64],
32 eris: &[f64],
33 n_basis: usize,
34) -> Result<Vec<f64>, String> {
35 if n_basis < GPU_DISPATCH_THRESHOLD {
36 return Err("Basis too small for GPU dispatch".to_string());
37 }
38
39 let n2 = n_basis * n_basis;
40 if h_core.len() != n2 || density.len() != n2 {
41 return Err("Matrix dimension mismatch".to_string());
42 }
43
44 let h_core_f32: Vec<f32> = h_core.iter().map(|v| *v as f32).collect();
46 let density_f32: Vec<f32> = density.iter().map(|v| *v as f32).collect();
47 let eris_f32: Vec<f32> = eris.iter().map(|v| *v as f32).collect();
48
49 let params = pack_uniform_values(&[
51 UniformValue::U32(n_basis as u32),
52 UniformValue::U32(0),
53 UniformValue::U32(0),
54 UniformValue::U32(0),
55 ]);
56
57 let output_seed = vec![0.0f32; n2];
59
60 let wg_size = 16u32;
61 let wg_x = ceil_div_u32(n_basis, wg_size);
62 let wg_y = wg_x;
63
64 let descriptor = ComputeDispatchDescriptor {
65 label: "fock matrix build".to_string(),
66 shader_source: FOCK_BUILD_SHADER.to_string(),
67 entry_point: "main".to_string(),
68 workgroup_count: [wg_x, wg_y, 1],
69 bindings: vec![
70 ComputeBindingDescriptor {
71 label: "h_core".to_string(),
72 kind: ComputeBindingKind::StorageReadOnly,
73 bytes: f32_slice_to_bytes(&h_core_f32),
74 },
75 ComputeBindingDescriptor {
76 label: "density".to_string(),
77 kind: ComputeBindingKind::StorageReadOnly,
78 bytes: f32_slice_to_bytes(&density_f32),
79 },
80 ComputeBindingDescriptor {
81 label: "eris".to_string(),
82 kind: ComputeBindingKind::StorageReadOnly,
83 bytes: f32_slice_to_bytes(&eris_f32),
84 },
85 ComputeBindingDescriptor {
86 label: "params".to_string(),
87 kind: ComputeBindingKind::Uniform,
88 bytes: params,
89 },
90 ComputeBindingDescriptor {
91 label: "output".to_string(),
92 kind: ComputeBindingKind::StorageReadWrite,
93 bytes: f32_slice_to_bytes(&output_seed),
94 },
95 ],
96 };
97
98 let mut result = ctx.run_compute(&descriptor)?;
99 let bytes = result
100 .outputs
101 .pop()
102 .ok_or("No output from Fock build kernel")?;
103
104 if bytes.len() != n2 * 4 {
105 return Err(format!(
106 "Fock output size mismatch: expected {}, got {}",
107 n2 * 4,
108 bytes.len()
109 ));
110 }
111
112 let fock = bytes_to_f64_vec_from_f32(&bytes);
113
114 Ok(fock)
115}
116
/// WGSL compute kernel that evaluates one Fock matrix element per invocation.
///
/// Invocation `(gid.x, gid.y) = (μ, ν)` writes
/// `fock[μ·n + ν] = h_core[μ·n + ν] + Σ_{λσ} P(λ,σ) · [(μν|λσ) − 0.5·(μλ|νσ)]`;
/// invocations with `μ ≥ n` or `ν ≥ n` return without writing.
///
/// Bindings (group 0):
/// * 0 — `h_core`, read-only `array<f32>`, element `(μ,ν)` at `μ·n + ν`
/// * 1 — `density`, read-only `array<f32>`, element `(λ,σ)` at `λ·n + σ`
/// * 2 — `eris`, read-only `array<f32>`, `(μν|λσ)` at `μ·n³ + ν·n² + λ·n + σ`
/// * 3 — `params`, uniform `Params { n_basis, 3 × padding }`
/// * 4 — `fock`, read-write `array<f32>` receiving the result
///
/// The workgroup size is 16×16×1, so `build_fock_gpu` dispatches
/// `ceil(n / 16)` workgroups along both x and y. All index arithmetic is `u32`.
pub const FOCK_BUILD_SHADER: &str = r#"
struct Params {
    n_basis: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> h_core: array<f32>;
@group(0) @binding(1) var<storage, read> density: array<f32>;
@group(0) @binding(2) var<storage, read> eris: array<f32>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> fock: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let mu = gid.x;
    let nu = gid.y;
    let n = params.n_basis;

    if (mu >= n || nu >= n) {
        return;
    }

    let n2 = n * n;
    var g_mn: f32 = 0.0;

    // G(μ,ν) = Σ_{λσ} P(λ,σ) · [(μν|λσ) - 0.5·(μλ|νσ)]
    for (var lam: u32 = 0u; lam < n; lam = lam + 1u) {
        for (var sig: u32 = 0u; sig < n; sig = sig + 1u) {
            let p_ls = density[lam * n + sig];

            // Coulomb: (μν|λσ)
            let j_idx = mu * n * n2 + nu * n2 + lam * n + sig;
            let j_val = eris[j_idx];

            // Exchange: (μλ|νσ)
            let k_idx = mu * n * n2 + lam * n2 + nu * n + sig;
            let k_val = eris[k_idx];

            g_mn += p_ls * (j_val - 0.5 * k_val);
        }
    }

    fock[mu * n + nu] = h_core[mu * n + nu] + g_mn;
}
"#;
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Three `f32` values serialize to 3 × 4 = 12 bytes.
    #[test]
    fn test_f32_slice_to_bytes() {
        let data = vec![1.0f32, 2.0, 3.0];
        let bytes = f32_slice_to_bytes(&data);
        assert_eq!(bytes.len(), 12);
    }

    /// A basis below `GPU_DISPATCH_THRESHOLD` is rejected by the size guard.
    ///
    /// The message is checked as well, so the test cannot pass merely because
    /// the call failed for an unrelated reason further down the pipeline.
    #[test]
    fn test_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let n = 2;
        let h = vec![0.0f64; n * n];
        let d = vec![0.0f64; n * n];
        let e = vec![0.0f64; n * n * n * n];

        let err = build_fock_gpu(&ctx, &h, &d, &e, n).unwrap_err();
        assert!(err.contains("too small"), "unexpected error: {}", err);
    }

    /// A density matrix with 16 entries does not match `n_basis = 5` (25 expected).
    #[test]
    fn test_dimension_mismatch() {
        let ctx = GpuContext::cpu_fallback();

        let err = build_fock_gpu(&ctx, &[0.0; 25], &[0.0; 16], &[0.0; 625], 5).unwrap_err();
        assert!(err.contains("mismatch"), "unexpected error: {}", err);
    }
}