sci_form/gpu/
gamma_matrix_gpu.rs1use nalgebra::DMatrix;
4
5use super::context::{
6 bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
7 pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
8 ComputeDispatchDescriptor, GpuContext, UniformValue,
9};
10
/// Minimum number of atoms (i.e. `eta.len()`) that `build_gamma_gpu` will
/// dispatch to the GPU; smaller inputs are rejected with an error.
const GPU_DISPATCH_THRESHOLD: usize = 2;
12
13pub fn build_gamma_gpu(
14 ctx: &GpuContext,
15 eta: &[f64],
16 positions_bohr: &[[f64; 3]],
17) -> Result<DMatrix<f64>, String> {
18 let n = eta.len();
19 if n < GPU_DISPATCH_THRESHOLD {
20 return Err("Matrix too small for GPU dispatch".to_string());
21 }
22 if positions_bohr.len() != n {
23 return Err("eta/position length mismatch".to_string());
24 }
25
26 let eta_f32: Vec<f32> = eta.iter().map(|value| *value as f32).collect();
27 let params = pack_uniform_values(&[
28 UniformValue::U32(n as u32),
29 UniformValue::U32(0),
30 UniformValue::U32(0),
31 UniformValue::U32(0),
32 ]);
33
34 let output_seed = vec![0.0f32; n * n];
35 let wg_size = 16u32;
36 let wg_x = ceil_div_u32(n, wg_size);
37 let wg_y = wg_x;
38
39 let descriptor = ComputeDispatchDescriptor {
40 label: "gamma matrix".to_string(),
41 shader_source: GAMMA_MATRIX_SHADER.to_string(),
42 entry_point: "main".to_string(),
43 workgroup_count: [wg_x, wg_y, 1],
44 bindings: vec![
45 ComputeBindingDescriptor {
46 label: "eta".to_string(),
47 kind: ComputeBindingKind::StorageReadOnly,
48 bytes: f32_slice_to_bytes(&eta_f32),
49 },
50 ComputeBindingDescriptor {
51 label: "positions".to_string(),
52 kind: ComputeBindingKind::StorageReadOnly,
53 bytes: pack_vec3_positions_f32(positions_bohr),
54 },
55 ComputeBindingDescriptor {
56 label: "params".to_string(),
57 kind: ComputeBindingKind::Uniform,
58 bytes: params,
59 },
60 ComputeBindingDescriptor {
61 label: "output".to_string(),
62 kind: ComputeBindingKind::StorageReadWrite,
63 bytes: f32_slice_to_bytes(&output_seed),
64 },
65 ],
66 };
67
68 let mut result = ctx.run_compute(&descriptor)?;
69 let bytes = result.outputs.pop().ok_or("No output from gamma kernel")?;
70 if bytes.len() != n * n * 4 {
71 return Err(format!(
72 "Gamma output size mismatch: expected {}, got {}",
73 n * n * 4,
74 bytes.len()
75 ));
76 }
77
78 let values = bytes_to_f64_vec_from_f32(&bytes);
79 Ok(DMatrix::from_row_slice(n, n, &values))
80}
81
/// WGSL compute kernel that fills the gamma matrix consumed by `build_gamma_gpu`.
///
/// Bindings (all in group 0):
/// - `0`: `eta`, read-only storage `array<f32>`, the diagonal values.
/// - `1`: `positions`, read-only storage `array<AtomPos>`, each `x, y, z` plus one `f32` pad.
/// - `2`: `params`, uniform `Params`; only `n_atoms` is read, the rest is padding.
/// - `3`: `output`, read-write storage `array<f32>`, written at flat index `a * n_atoms + b`.
///
/// One invocation handles `(a, b) = (gid.x, gid.y)` in 16x16 workgroups, and
/// invocations with `a` or `b` out of range return early. The diagonal receives
/// `eta[a]`; off-diagonal entries receive `1.0 / r`, where `r` is the Euclidean
/// distance between atoms `a` and `b`.
///
/// NOTE(review): the flat index `a * n + b` is computed in u32, so `n_atoms` must
/// stay at or below 65535 for indices not to wrap. Two distinct atoms at the same
/// position give `r == 0`. WGSL does not guarantee that a runtime `1.0 / 0.0`
/// yields +inf, so the resulting entry may be unspecified. Confirm whether
/// callers can ever supply coincident atoms.
pub const GAMMA_MATRIX_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> eta: array<f32>;
@group(0) @binding(1) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let a = gid.x;
    let b = gid.y;
    let n = params.n_atoms;

    if (a >= n || b >= n) {
        return;
    }

    if (a == b) {
        output[a * n + b] = eta[a];
        return;
    }

    let pa = positions[a];
    let pb = positions[b];
    let dx = pa.x - pb.x;
    let dy = pa.y - pb.y;
    let dz = pa.z - pb.z;
    let r = sqrt(dx * dx + dy * dy + dz * dz);

    output[a * n + b] = 1.0 / r;
}
"#;
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A single atom is below `GPU_DISPATCH_THRESHOLD` and must be rejected.
    #[test]
    fn test_gamma_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let result = build_gamma_gpu(&ctx, &[0.5], &[[0.0, 0.0, 0.0]]);
        assert!(result.is_err());
    }

    /// Empty input is also below the threshold and must be rejected, not dispatched.
    #[test]
    fn test_gamma_gpu_empty_input() {
        let ctx = GpuContext::cpu_fallback();
        assert!(build_gamma_gpu(&ctx, &[], &[]).is_err());
    }

    /// Mismatched `eta`/`positions` lengths are reported before any GPU work.
    #[test]
    fn test_gamma_gpu_length_mismatch() {
        let ctx = GpuContext::cpu_fallback();
        let result = build_gamma_gpu(&ctx, &[0.5, 0.7], &[[0.0, 0.0, 0.0]]);
        assert_eq!(
            result.err().as_deref(),
            Some("eta/position length mismatch")
        );
    }
}