1use super::context::{
11 ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
12};
13use crate::scf::basis::BasisSet;
14use crate::scf::kinetic_matrix::build_kinetic_matrix;
15use crate::scf::nuclear_matrix::build_nuclear_matrix;
16use crate::scf::overlap_matrix::build_overlap_matrix;
17
/// Number of primitive slots reserved per basis function in the packed GPU primitive buffer.
///
/// The WGSL kernel hard-codes the same stride (`compute_mu * 3u + pa`), so the two values must
/// always be changed together.
const MAX_PRIMITIVES: usize = 3;

/// Smallest `n_basis` for which a GPU dispatch is attempted; smaller bases are rejected with an
/// error by `compute_one_electron_gpu`.
const GPU_DISPATCH_THRESHOLD: usize = 6;
23
/// Overlap, kinetic and nuclear-attraction matrices for one basis set, plus provenance info.
///
/// All three matrices are stored row-major as flat `n_basis * n_basis` vectors.
#[derive(Debug, Clone)]
pub struct OneElectronResult {
    /// Overlap matrix S, row-major.
    pub overlap: Vec<f64>,
    /// Kinetic-energy matrix T, row-major.
    pub kinetic: Vec<f64>,
    /// Nuclear-attraction matrix V, row-major.
    pub nuclear: Vec<f64>,
    /// Number of basis functions (the side length of each matrix).
    pub n_basis: usize,
    /// `true` when the matrices came from the WGSL kernel, `false` for the CPU builders.
    pub used_gpu: bool,
    /// Backend label: `"CPU-exact"` on the CPU path, otherwise the GPU context's backend name.
    pub backend: String,
    /// Human-readable explanation of which path was taken and why.
    pub note: String,
}
41
42fn flatten_matrix_row_major(matrix: &nalgebra::DMatrix<f64>) -> Vec<f64> {
43 (0..matrix.nrows())
44 .flat_map(|i| (0..matrix.ncols()).map(move |j| matrix[(i, j)]))
45 .collect()
46}
47
48fn compute_one_electron_cpu_exact(
49 basis: &BasisSet,
50 elements: &[u8],
51 positions_bohr: &[[f64; 3]],
52 note: impl Into<String>,
53) -> OneElectronResult {
54 let overlap = build_overlap_matrix(basis);
55 let kinetic = build_kinetic_matrix(basis);
56 let nuclear = build_nuclear_matrix(basis, elements, positions_bohr);
57
58 OneElectronResult {
59 overlap: flatten_matrix_row_major(&overlap),
60 kinetic: flatten_matrix_row_major(&kinetic),
61 nuclear: flatten_matrix_row_major(&nuclear),
62 n_basis: basis.n_basis,
63 used_gpu: false,
64 backend: "CPU-exact".to_string(),
65 note: note.into(),
66 }
67}
68
69fn basis_supports_exact_gpu_kernel(basis: &BasisSet) -> bool {
70 basis.functions.iter().all(|bf| bf.l_total == 0)
71}
72
73fn pack_one_electron_data(
84 basis: &BasisSet,
85 elements: &[u8],
86 positions_bohr: &[[f64; 3]],
87) -> (Vec<u8>, Vec<u8>, Vec<u8>) {
88 let mut basis_bytes = Vec::with_capacity(basis.n_basis * 32);
89 let mut prim_bytes = Vec::with_capacity(basis.n_basis * MAX_PRIMITIVES * 8);
90
91 for bf in &basis.functions {
92 basis_bytes.extend_from_slice(&(bf.center[0] as f32).to_ne_bytes());
93 basis_bytes.extend_from_slice(&(bf.center[1] as f32).to_ne_bytes());
94 basis_bytes.extend_from_slice(&(bf.center[2] as f32).to_ne_bytes());
95 basis_bytes.extend_from_slice(&bf.angular[0].to_ne_bytes());
96 basis_bytes.extend_from_slice(&bf.angular[1].to_ne_bytes());
97 basis_bytes.extend_from_slice(&bf.angular[2].to_ne_bytes());
98 basis_bytes.extend_from_slice(&(bf.primitives.len() as u32).to_ne_bytes());
99 basis_bytes.extend_from_slice(&0u32.to_ne_bytes());
100
101 for i in 0..MAX_PRIMITIVES {
102 if i < bf.primitives.len() {
103 let norm = crate::scf::basis::BasisFunction::normalization(
104 bf.primitives[i].alpha,
105 bf.angular[0],
106 bf.angular[1],
107 bf.angular[2],
108 );
109 prim_bytes.extend_from_slice(&(bf.primitives[i].alpha as f32).to_ne_bytes());
110 prim_bytes.extend_from_slice(
111 &((bf.primitives[i].coefficient * norm) as f32).to_ne_bytes(),
112 );
113 } else {
114 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
115 prim_bytes.extend_from_slice(&0.0f32.to_ne_bytes());
116 }
117 }
118 }
119
120 let mut atom_bytes = Vec::with_capacity(elements.len() * 16);
122 for (i, &z) in elements.iter().enumerate() {
123 atom_bytes.extend_from_slice(&(positions_bohr[i][0] as f32).to_ne_bytes());
124 atom_bytes.extend_from_slice(&(positions_bohr[i][1] as f32).to_ne_bytes());
125 atom_bytes.extend_from_slice(&(positions_bohr[i][2] as f32).to_ne_bytes());
126 atom_bytes.extend_from_slice(&(z as f32).to_ne_bytes());
127 }
128
129 (basis_bytes, prim_bytes, atom_bytes)
130}
131
132pub fn compute_one_electron_gpu(
136 ctx: &GpuContext,
137 basis: &BasisSet,
138 elements: &[u8],
139 positions_bohr: &[[f64; 3]],
140) -> Result<OneElectronResult, String> {
141 let n = basis.n_basis;
142 let n_atoms = elements.len();
143
144 if n < GPU_DISPATCH_THRESHOLD {
145 return Err("Basis too small for GPU dispatch".to_string());
146 }
147
148 if !basis_supports_exact_gpu_kernel(basis) {
149 return Ok(compute_one_electron_cpu_exact(
150 basis,
151 elements,
152 positions_bohr,
153 "Fell back to exact CPU one-electron builders because the current WGSL kernel only supports pure s-type basis functions.",
154 ));
155 }
156
157 let (basis_bytes, prim_bytes, atom_bytes) =
158 pack_one_electron_data(basis, elements, positions_bohr);
159
160 let mut params = Vec::with_capacity(16);
162 params.extend_from_slice(&(n as u32).to_ne_bytes());
163 params.extend_from_slice(&(n_atoms as u32).to_ne_bytes());
164 params.extend_from_slice(&0u32.to_ne_bytes());
165 params.extend_from_slice(&0u32.to_ne_bytes());
166
167 let output_size = 3 * n * n;
169 let output_seed = vec![0.0f32; output_size];
170 let output_bytes: Vec<u8> = output_seed.iter().flat_map(|v| v.to_ne_bytes()).collect();
171
172 let wg_size = 16u32;
173 let wg_x = (n as u32).div_ceil(wg_size);
174 let wg_y = wg_x;
175
176 let descriptor = ComputeDispatchDescriptor {
177 label: "one-electron matrices".to_string(),
178 shader_source: ONE_ELECTRON_SHADER.to_string(),
179 entry_point: "main".to_string(),
180 workgroup_count: [wg_x, wg_y, 1],
181 bindings: vec![
182 ComputeBindingDescriptor {
183 label: "basis".to_string(),
184 kind: ComputeBindingKind::StorageReadOnly,
185 bytes: basis_bytes,
186 },
187 ComputeBindingDescriptor {
188 label: "primitives".to_string(),
189 kind: ComputeBindingKind::StorageReadOnly,
190 bytes: prim_bytes,
191 },
192 ComputeBindingDescriptor {
193 label: "atoms".to_string(),
194 kind: ComputeBindingKind::StorageReadOnly,
195 bytes: atom_bytes,
196 },
197 ComputeBindingDescriptor {
198 label: "params".to_string(),
199 kind: ComputeBindingKind::Uniform,
200 bytes: params,
201 },
202 ComputeBindingDescriptor {
203 label: "output".to_string(),
204 kind: ComputeBindingKind::StorageReadWrite,
205 bytes: output_bytes,
206 },
207 ],
208 };
209
210 let mut result = ctx.run_compute(&descriptor)?;
211 let bytes = result
212 .outputs
213 .pop()
214 .ok_or("No output from one-electron kernel")?;
215
216 if bytes.len() != output_size * 4 {
217 return Err(format!(
218 "Output size mismatch: expected {}, got {}",
219 output_size * 4,
220 bytes.len()
221 ));
222 }
223
224 let all_f64: Vec<f64> = bytes
225 .chunks_exact(4)
226 .map(|c| f32::from_ne_bytes([c[0], c[1], c[2], c[3]]) as f64)
227 .collect();
228
229 let n2 = n * n;
230 Ok(OneElectronResult {
231 overlap: all_f64[..n2].to_vec(),
232 kinetic: all_f64[n2..2 * n2].to_vec(),
233 nuclear: all_f64[2 * n2..].to_vec(),
234 n_basis: n,
235 used_gpu: true,
236 backend: ctx.capabilities.backend.clone(),
237 note: "Executed WGSL one-electron kernel for a pure s-type basis.".to_string(),
238 })
239}
240
/// WGSL compute kernel for the one-electron matrices of a contracted s-type Gaussian basis.
///
/// Bindings (group 0):
/// - `0`: basis-function records.
/// - `1`: `(alpha, coefficient * norm)` primitive pairs, three slots per basis function.
/// - `2`: atoms (position and charge).
/// - `3`: uniform `n_basis` / `n_atoms`.
/// - `4`: the `3 * n * n` output buffer.
///
/// Each invocation evaluates one `(mu, nu)` element of S, T and V in single precision.
pub const ONE_ELECTRON_SHADER: &str = r#"
// One-electron integrals over contracted s-type Gaussians, evaluated in f32.
// The lx/ly/lz fields are carried in the record but not used by these formulas.

struct BasisFunc {
    cx: f32, cy: f32, cz: f32,
    lx: u32, ly: u32, lz: u32,
    n_prims: u32, _pad: u32,
};

struct Atom {
    x: f32, y: f32, z: f32,
    charge: f32,
};

struct Params {
    n_basis: u32,
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
};

@group(0) @binding(0) var<storage, read> basis: array<BasisFunc>;
@group(0) @binding(1) var<storage, read> primitives: array<vec2<f32>>;
@group(0) @binding(2) var<storage, read> atoms: array<Atom>;
@group(0) @binding(3) var<uniform> params: Params;
@group(0) @binding(4) var<storage, read_write> output: array<f32>;
// Output layout: [S_00..S_{nn}] [T_00..T_{nn}] [V_00..V_{nn}]

// Boys function F0(x) = sqrt(pi / x) / 2 * erf(sqrt(x)).
fn boys_f0(x: f32) -> f32 {
    if (x < 1e-7) {
        return 1.0;
    }
    // For x > 15, erfc(sqrt(x)) < 5e-8, so the asymptotic form sqrt(pi / x) / 2 is exact to
    // f32 precision. A higher cutoff would leave the 50-term series below unconverged
    // (about 2e-4 relative error near x = 30).
    if (x > 15.0) {
        return 0.8862269 / sqrt(x);
    }
    // Series F0(x) = exp(-x) * sum_k (2x)^k / (2k+1)!!. All terms are positive, and for
    // x <= 15 the sum reaches f32 precision well within 50 terms.
    var sum: f32 = 1.0;
    var term: f32 = 1.0;
    for (var k: u32 = 1u; k < 50u; k = k + 1u) {
        term *= 2.0 * x / f32(2u * k + 1u);
        sum += term;
        if (abs(term) < 1e-7 * abs(sum)) {
            break;
        }
    }
    return exp(-x) * sum;
}

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let mu = gid.x;
    let nu = gid.y;
    let n = params.n_basis;
    let n2 = n * n;

    if (mu >= n || nu >= n) {
        return;
    }

    // The threads for (mu, nu) and (nu, mu) both evaluate the pair with the lower index first,
    // so S, T and V come out bitwise symmetric.
    let compute_mu = min(mu, nu);
    let compute_nu = max(mu, nu);

    let bf_a = basis[compute_mu];
    let bf_b = basis[compute_nu];

    var s_total: f32 = 0.0;
    var t_total: f32 = 0.0;
    var v_total: f32 = 0.0;

    // Contract over primitives (fixed stride of 3 slots per basis function)
    for (var pa: u32 = 0u; pa < bf_a.n_prims; pa = pa + 1u) {
        let prim_a = primitives[compute_mu * 3u + pa];
        let alpha = prim_a.x;
        let ca = prim_a.y;

        for (var pb: u32 = 0u; pb < bf_b.n_prims; pb = pb + 1u) {
            let prim_b = primitives[compute_nu * 3u + pb];
            let beta = prim_b.x;
            let cb = prim_b.y;

            let p = alpha + beta;
            let mu_ab = alpha * beta / p;
            let ab2 = (bf_a.cx - bf_b.cx) * (bf_a.cx - bf_b.cx)
                    + (bf_a.cy - bf_b.cy) * (bf_a.cy - bf_b.cy)
                    + (bf_a.cz - bf_b.cz) * (bf_a.cz - bf_b.cz);
            let k_ab = exp(-mu_ab * ab2);

            // Overlap: S = (pi/p)^{3/2} * K_ab
            let pi = 3.14159265;
            let s_prim = pow(pi / p, 1.5) * k_ab;
            s_total += ca * cb * s_prim;

            // Kinetic: T = mu * (3 - 2 * mu * |AB|^2) * S  [s-type closed form]
            let t_prim = mu_ab * (3.0 - 2.0 * mu_ab * ab2) * s_prim;
            t_total += ca * cb * t_prim;

            // Nuclear attraction: V = -sum_C Z_C * 2pi/p * K_ab * F0(p * |PC|^2)
            let px = (alpha * bf_a.cx + beta * bf_b.cx) / p;
            let py = (alpha * bf_a.cy + beta * bf_b.cy) / p;
            let pz = (alpha * bf_a.cz + beta * bf_b.cz) / p;

            for (var c: u32 = 0u; c < params.n_atoms; c = c + 1u) {
                let atom = atoms[c];
                let pc2 = (px - atom.x) * (px - atom.x)
                        + (py - atom.y) * (py - atom.y)
                        + (pz - atom.z) * (pz - atom.z);
                let v_prim = -atom.charge * 2.0 * pi / p * k_ab * boys_f0(p * pc2);
                v_total += ca * cb * v_prim;
            }
        }
    }

    // Write to output: S at offset 0, T at offset n^2, V at offset 2n^2
    output[mu * n + nu] = s_total;
    output[n2 + mu * n + nu] = t_total;
    output[2u * n2 + mu * n + nu] = v_total;
}
"#;
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pack_one_electron() {
        let elements = [1u8, 1];
        let positions = [[0.0, 0.0, 0.0], [1.4, 0.0, 0.0]];
        let basis = BasisSet::sto3g(&elements, &positions);

        let (basis_bytes, prim_bytes, atom_bytes) =
            pack_one_electron_data(&basis, &elements, &positions);

        // 32-byte basis records, MAX_PRIMITIVES 8-byte primitive slots each, 16-byte atoms.
        assert_eq!(basis_bytes.len(), 32 * basis.n_basis);
        assert_eq!(prim_bytes.len(), 8 * MAX_PRIMITIVES * basis.n_basis);
        assert_eq!(atom_bytes.len(), 16 * elements.len());
    }

    #[test]
    fn test_gpu_threshold() {
        let elements = [1u8];
        let positions = [[0.0, 0.0, 0.0]];
        let basis = BasisSet::sto3g(&elements, &positions);
        let ctx = GpuContext::cpu_fallback();

        // A lone hydrogen atom is below GPU_DISPATCH_THRESHOLD, so the call must be rejected.
        let outcome = compute_one_electron_gpu(&ctx, &basis, &elements, &positions);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_mixed_angular_momentum_falls_back_to_exact_cpu() {
        let elements = [8u8, 1, 1];
        let positions = [
            [0.0, 0.0, 0.117],
            [0.0, 0.757, -0.469],
            [0.0, -0.757, -0.469],
        ];
        let basis = BasisSet::sto3g(&elements, &positions);
        let ctx = GpuContext::cpu_fallback();

        let result = compute_one_electron_gpu(&ctx, &basis, &elements, &positions)
            .expect("mixed-angular basis should fall back to CPU");

        assert!(!result.used_gpu);
        assert_eq!(result.backend, "CPU-exact");
        assert!(result.note.contains("pure s-type"));

        let reference = flatten_matrix_row_major(&build_overlap_matrix(&basis));
        assert_eq!(result.overlap.len(), reference.len());
        for (idx, (got, want)) in result.overlap.iter().zip(&reference).enumerate() {
            assert!((got - want).abs() < 1e-12, "overlap[{idx}]: {got} vs {want}");
        }
    }
}