Skip to main content

sci_form/gpu/
gamma_matrix_gpu.rs

//! GPU-accelerated SCC-DFTB gamma-matrix construction.
2
3use nalgebra::DMatrix;
4
5use super::context::{
6    bytes_to_f64_vec_from_f32, ceil_div_u32, f32_slice_to_bytes, pack_uniform_values,
7    pack_vec3_positions_f32, ComputeBindingDescriptor, ComputeBindingKind,
8    ComputeDispatchDescriptor, GpuContext, UniformValue,
9};
10
11const GPU_DISPATCH_THRESHOLD: usize = 2;
12
13pub fn build_gamma_gpu(
14    ctx: &GpuContext,
15    eta: &[f64],
16    positions_bohr: &[[f64; 3]],
17) -> Result<DMatrix<f64>, String> {
18    let n = eta.len();
19    if n < GPU_DISPATCH_THRESHOLD {
20        return Err("Matrix too small for GPU dispatch".to_string());
21    }
22    if positions_bohr.len() != n {
23        return Err("eta/position length mismatch".to_string());
24    }
25
26    let eta_f32: Vec<f32> = eta.iter().map(|value| *value as f32).collect();
27    let params = pack_uniform_values(&[
28        UniformValue::U32(n as u32),
29        UniformValue::U32(0),
30        UniformValue::U32(0),
31        UniformValue::U32(0),
32    ]);
33
34    let output_seed = vec![0.0f32; n * n];
35    let wg_size = 16u32;
36    let wg_x = ceil_div_u32(n, wg_size);
37    let wg_y = wg_x;
38
39    let descriptor = ComputeDispatchDescriptor {
40        label: "gamma matrix".to_string(),
41        shader_source: GAMMA_MATRIX_SHADER.to_string(),
42        entry_point: "main".to_string(),
43        workgroup_count: [wg_x, wg_y, 1],
44        bindings: vec![
45            ComputeBindingDescriptor {
46                label: "eta".to_string(),
47                kind: ComputeBindingKind::StorageReadOnly,
48                bytes: f32_slice_to_bytes(&eta_f32),
49            },
50            ComputeBindingDescriptor {
51                label: "positions".to_string(),
52                kind: ComputeBindingKind::StorageReadOnly,
53                bytes: pack_vec3_positions_f32(positions_bohr),
54            },
55            ComputeBindingDescriptor {
56                label: "params".to_string(),
57                kind: ComputeBindingKind::Uniform,
58                bytes: params,
59            },
60            ComputeBindingDescriptor {
61                label: "output".to_string(),
62                kind: ComputeBindingKind::StorageReadWrite,
63                bytes: f32_slice_to_bytes(&output_seed),
64            },
65        ],
66    };
67
68    let mut result = ctx.run_compute(&descriptor)?;
69    let bytes = result.outputs.pop().ok_or("No output from gamma kernel")?;
70    if bytes.len() != n * n * 4 {
71        return Err(format!(
72            "Gamma output size mismatch: expected {}, got {}",
73            n * n * 4,
74            bytes.len()
75        ));
76    }
77
78    let values = bytes_to_f64_vec_from_f32(&bytes);
79    Ok(DMatrix::from_row_slice(n, n, &values))
80}
81
/// WGSL compute shader that fills the `n x n` gamma matrix.
///
/// Each invocation handles one element: `a = global_invocation_id.x` and
/// `b = global_invocation_id.y`. Workgroups are 16x16, and invocations with
/// `a >= n` or `b >= n` return without writing. The value is stored row-major
/// at `output[a * n + b]`:
/// * diagonal (`a == b`): `eta[a]`;
/// * off-diagonal: `1 / r`, where `r` is the Euclidean distance between
///   `positions[a]` and `positions[b]`.
///
/// Bindings (group 0):
/// * 0: `eta`, a read-only `f32` array;
/// * 1: `positions`, a read-only array of 16-byte `AtomPos`
///   (`x, y, z` plus one padding float);
/// * 2: `params`, a uniform holding `n_atoms` plus three padding words;
/// * 3: `output`, a read-write `f32` array.
///
/// NOTE(review): `a * n + b` is computed in `u32`, so the host must ensure
/// `n * n` fits in a `u32`.
/// NOTE(review): coincident atoms (`r == 0`) are not guarded and produce a
/// non-finite/indeterminate off-diagonal value.
/// NOTE(review): the off-diagonal term is the bare `1 / r` with no
/// `eta`-dependent short-range correction — confirm this matches the CPU
/// gamma implementation.
pub const GAMMA_MATRIX_SHADER: &str = r#"
struct AtomPos {
    x: f32, y: f32, z: f32, _pad: f32,
};

struct Params {
    n_atoms: u32,
    _pad0: u32,
    _pad1: u32,
    _pad2: u32,
};

@group(0) @binding(0) var<storage, read> eta: array<f32>;
@group(0) @binding(1) var<storage, read> positions: array<AtomPos>;
@group(0) @binding(2) var<uniform> params: Params;
@group(0) @binding(3) var<storage, read_write> output: array<f32>;

@compute @workgroup_size(16, 16, 1)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let a = gid.x;
    let b = gid.y;
    let n = params.n_atoms;

    if (a >= n || b >= n) {
        return;
    }

    if (a == b) {
        output[a * n + b] = eta[a];
        return;
    }

    let pa = positions[a];
    let pb = positions[b];
    let dx = pa.x - pb.x;
    let dy = pa.y - pb.y;
    let dz = pa.z - pb.z;
    let r = sqrt(dx * dx + dy * dy + dz * dz);

    output[a * n + b] = 1.0 / r;
}
"#;
124
#[cfg(test)]
mod tests {
    use super::*;

    /// A single atom is below `GPU_DISPATCH_THRESHOLD`. The error must come
    /// from the size guard itself, not from a later dispatch failure on the
    /// CPU-fallback context.
    #[test]
    fn test_gamma_gpu_threshold() {
        let ctx = GpuContext::cpu_fallback();
        let result = build_gamma_gpu(&ctx, &[0.5], &[[0.0, 0.0, 0.0]]);
        assert_eq!(
            result.err().as_deref(),
            Some("Matrix too small for GPU dispatch")
        );
    }

    /// Empty input is also below the dispatch threshold.
    #[test]
    fn test_gamma_gpu_empty_input() {
        let ctx = GpuContext::cpu_fallback();
        let result = build_gamma_gpu(&ctx, &[], &[]);
        assert_eq!(
            result.err().as_deref(),
            Some("Matrix too small for GPU dispatch")
        );
    }

    /// Mismatched `eta` / position lengths are rejected before any dispatch.
    #[test]
    fn test_gamma_gpu_length_mismatch() {
        let ctx = GpuContext::cpu_fallback();
        let result = build_gamma_gpu(&ctx, &[0.5, 0.5], &[[0.0, 0.0, 0.0]]);
        assert_eq!(
            result.err().as_deref(),
            Some("eta/position length mismatch")
        );
    }
}
135}