Skip to main content

sci_form/gpu/
stda_gpu.rs

1//! GPU-accelerated sTDA Coulomb matrix construction.
2//!
//! Offloads the J-integral computation to the GPU; summed term by term it costs
//! O(n_singles² · n_atoms²).
4//! The inner kernel evaluates:
5//!   J_{ia,jb} = Σ_{A,B} q^A_{ia} · γ_{AB} · q^B_{jb}
6//!
7//! This is a GEMM-like operation on the transition charge matrix:
8//!   A_off_diag = 2 · Q^T · Γ · Q
9//! where Q is (n_atoms × n_singles) and Γ is (n_atoms × n_atoms).
10
11use super::context::{
12    ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13};
14
/// Smallest number of single excitations for which `compute_stda_j_matrix_gpu` attempts a
/// GPU dispatch; smaller problems go straight to the CPU implementation.
const GPU_DISPATCH_THRESHOLD: usize = 100;
17
18/// GPU-accelerated sTDA J-integral matrix: A_{off} = 2 · Q^T · Γ · Q
19///
20/// `q_matrix`: transition charges, shape (n_atoms, n_singles), row-major flat.
21/// `gamma`: damped Coulomb matrix, shape (n_atoms, n_atoms), row-major flat.
22/// `n_atoms`: number of atoms.
23/// `n_singles`: number of single excitations.
24///
25/// Returns the off-diagonal contribution to the A matrix (n_singles × n_singles), row-major.
26pub fn compute_stda_j_matrix_gpu(
27    ctx: &GpuContext,
28    q_matrix: &[f64],
29    gamma: &[f64],
30    n_atoms: usize,
31    n_singles: usize,
32) -> Result<Vec<f64>, String> {
33    if n_singles < GPU_DISPATCH_THRESHOLD || !ctx.capabilities.gpu_available {
34        return compute_stda_j_matrix_cpu(q_matrix, gamma, n_atoms, n_singles);
35    }
36
37    // For GPU path: compute GammaQ = Γ · Q (n_atoms × n_singles)
38    // Then A_off = 2 · Q^T · GammaQ (n_singles × n_singles)
39    //
40    // Both are matrix multiplies which map well to GPU compute shaders.
41    // For now, use the GPU's general matrix multiply dispatch.
42
43    let q_f32: Vec<f32> = q_matrix.iter().map(|&x| x as f32).collect();
44    let gamma_f32: Vec<f32> = gamma.iter().map(|&x| x as f32).collect();
45
46    // Step 1: GammaQ = Γ · Q
47    let gamma_q_bytes = vec![0u8; n_atoms * n_singles * 4];
48
49    let dispatch = ComputeDispatchDescriptor {
50        label: "stda_gamma_q".to_string(),
51        shader_source: MATMUL_SHADER.to_string(),
52        entry_point: "main".to_string(),
53        workgroup_count: [
54            n_atoms.div_ceil(16) as u32,
55            n_singles.div_ceil(16) as u32,
56            1,
57        ],
58        bindings: vec![
59            ComputeBindingDescriptor {
60                label: "gamma".to_string(),
61                kind: ComputeBindingKind::StorageReadOnly,
62                bytes: bytemuck_cast_f32(&gamma_f32),
63            },
64            ComputeBindingDescriptor {
65                label: "q".to_string(),
66                kind: ComputeBindingKind::StorageReadOnly,
67                bytes: bytemuck_cast_f32(&q_f32),
68            },
69            ComputeBindingDescriptor {
70                label: "result".to_string(),
71                kind: ComputeBindingKind::StorageReadWrite,
72                bytes: gamma_q_bytes,
73            },
74            ComputeBindingDescriptor {
75                label: "dims".to_string(),
76                kind: ComputeBindingKind::Uniform,
77                bytes: pack_dims(n_atoms as u32, n_atoms as u32, n_singles as u32),
78            },
79        ],
80    };
81
82    let gamma_q_result = ctx
83        .run_compute(&dispatch)?
84        .outputs
85        .into_iter()
86        .last()
87        .unwrap_or_default();
88
89    // Step 2: A_off = 2 · Q^T · GammaQ
90    let result_bytes = vec![0u8; n_singles * n_singles * 4];
91
92    let dispatch2 = ComputeDispatchDescriptor {
93        label: "stda_qt_gamma_q".to_string(),
94        shader_source: MATMUL_TRANSPOSE_SHADER.to_string(),
95        entry_point: "main".to_string(),
96        workgroup_count: [
97            n_singles.div_ceil(16) as u32,
98            n_singles.div_ceil(16) as u32,
99            1,
100        ],
101        bindings: vec![
102            ComputeBindingDescriptor {
103                label: "q".to_string(),
104                kind: ComputeBindingKind::StorageReadOnly,
105                bytes: bytemuck_cast_f32(&q_f32),
106            },
107            ComputeBindingDescriptor {
108                label: "gamma_q".to_string(),
109                kind: ComputeBindingKind::StorageReadOnly,
110                bytes: gamma_q_result,
111            },
112            ComputeBindingDescriptor {
113                label: "result".to_string(),
114                kind: ComputeBindingKind::StorageReadWrite,
115                bytes: result_bytes,
116            },
117            ComputeBindingDescriptor {
118                label: "dims".to_string(),
119                kind: ComputeBindingKind::Uniform,
120                bytes: pack_dims(n_atoms as u32, n_singles as u32, n_singles as u32),
121            },
122        ],
123    };
124
125    let a_off_bytes = ctx
126        .run_compute(&dispatch2)?
127        .outputs
128        .into_iter()
129        .last()
130        .unwrap_or_default();
131
132    // Convert f32 result back to f64 with 2× scaling
133    let a_off_f32: &[f32] = bytemuck_cast_from_u8(&a_off_bytes);
134    Ok(a_off_f32.iter().map(|&x| 2.0 * x as f64).collect())
135}
136
/// CPU implementation of the sTDA J-integral matrix `A_off = 2 · Q^T · Γ · Q`.
///
/// Uses the same row-major layouts as [`compute_stda_j_matrix_gpu`]: `q_matrix` is
/// (n_atoms × n_singles) and `gamma` is (n_atoms × n_atoms). Slices longer than those shapes
/// are accepted; only the leading elements are read.
///
/// The product is evaluated in two stages, first `Γ · Q` and then `Q^T · (Γ · Q)`. This costs
/// O(n_atoms² · n_singles + n_atoms · n_singles²) instead of the O(n_atoms² · n_singles²) of
/// summing every `q · γ · q` term separately. Only the lower triangle (`jb <= ia`) is computed
/// and then mirrored, so the result is exactly symmetric; that is correct only when `gamma` is
/// symmetric. Atoms whose transition charge for excitation `ia` has magnitude below `1e-12` are
/// skipped.
///
/// # Errors
///
/// Returns `Err` if `q_matrix` or `gamma` is shorter than its stated shape.
fn compute_stda_j_matrix_cpu(
    q_matrix: &[f64],
    gamma: &[f64],
    n_atoms: usize,
    n_singles: usize,
) -> Result<Vec<f64>, String> {
    let q_len = n_atoms * n_singles;
    let gamma_len = n_atoms * n_atoms;
    if q_matrix.len() < q_len {
        return Err(format!(
            "sTDA J matrix: q_matrix has {} elements, need {n_atoms} × {n_singles} = {q_len}",
            q_matrix.len()
        ));
    }
    if gamma.len() < gamma_len {
        return Err(format!(
            "sTDA J matrix: gamma has {} elements, need {n_atoms} × {n_atoms} = {gamma_len}",
            gamma.len()
        ));
    }

    // Stage 1: gamma_q[a, jb] = Σ_b γ[a, b] · q[b, jb], shape (n_atoms, n_singles).
    let mut gamma_q = vec![0.0; q_len];
    for a in 0..n_atoms {
        let gq_row = &mut gamma_q[a * n_singles..(a + 1) * n_singles];
        for b in 0..n_atoms {
            let g_ab = gamma[a * n_atoms + b];
            let q_row = &q_matrix[b * n_singles..(b + 1) * n_singles];
            for (acc, &q_b) in gq_row.iter_mut().zip(q_row) {
                *acc += g_ab * q_b;
            }
        }
    }

    // Stage 2: A[ia, jb] = 2 · Σ_a q[a, ia] · gamma_q[a, jb], lower triangle mirrored.
    let mut result = vec![0.0; n_singles * n_singles];
    for ia in 0..n_singles {
        for jb in 0..=ia {
            let mut val = 0.0;
            for a in 0..n_atoms {
                let q_a_ia = q_matrix[a * n_singles + ia];
                // Negligible transition charges are skipped, as in the term-by-term sum.
                if q_a_ia.abs() < 1e-12 {
                    continue;
                }
                val += q_a_ia * gamma_q[a * n_singles + jb];
            }
            result[ia * n_singles + jb] = 2.0 * val;
            result[jb * n_singles + ia] = 2.0 * val;
        }
    }

    Ok(result)
}
166
/// Serialises `f32` values into a byte buffer, four native-endian bytes per value in input
/// order, for use as a GPU storage binding.
fn bytemuck_cast_f32(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in data {
        bytes.extend_from_slice(&value.to_ne_bytes());
    }
    bytes
}
170
/// Views a native-endian byte buffer as `f32` values without copying.
///
/// Returns an empty slice whenever `align_to` cannot cover all of `data` with `f32`s (for
/// example a start address that is not 4-byte aligned, or a length that is not a multiple of
/// 4), so callers should compare the returned length against what they expect.
fn bytemuck_cast_from_u8(data: &[u8]) -> &[f32] {
    // SAFETY: every 32-bit pattern is a valid `f32`, so viewing suitably aligned bytes as `f32`
    // is sound. Alignment is NOT assumed here: `align_to` splits off any misaligned head or
    // partial tail, and such views are rejected below.
    let (head, floats, tail) = unsafe { data.align_to::<f32>() };
    match (head, tail) {
        ([], []) => floats,
        _ => &[],
    }
}
180
/// Serialises the 16-byte `Dims` uniform used by the matmul shaders: the three dimension words
/// in the order given, followed by one zero padding word, each as a native-endian `u32`.
/// The shaders read the words in struct declaration order (`MATMUL_SHADER` as M, K, N;
/// `MATMUL_TRANSPOSE_SHADER` as K, M, N).
fn pack_dims(m: u32, k: u32, n: u32) -> Vec<u8> {
    let mut packed = Vec::with_capacity(16);
    for word in [m, k, n, 0] {
        packed.extend_from_slice(&word.to_ne_bytes());
    }
    packed
}
189
/// WGSL shader for general matrix multiply: C = A × B
///
/// Bindings (group 0): 0 = `a`, an M × K row-major read-only matrix; 1 = `b`, a K × N
/// row-major read-only matrix; 2 = `c`, the M × N row-major output; 3 = the `Dims` uniform,
/// whose words are read as M, K, N, padding.
///
/// Each invocation computes one element of `c`. `row` comes from `global_invocation_id.x` and
/// `col` from `.y`, with a 16 × 16 workgroup, so a dispatch needs ceil(M/16) × ceil(N/16)
/// workgroups. Out-of-range invocations return early.
const MATMUL_SHADER: &str = r#"
struct Dims { M: u32, K: u32, N: u32, _pad: u32 }

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    let col = gid.y;
    if row >= dims.M || col >= dims.N { return; }

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
        sum = sum + a[row * dims.K + k] * b[k * dims.N + col];
    }
    c[row * dims.N + col] = sum;
}
"#;
212
/// WGSL shader for transpose-multiply: C = A^T × B
///
/// Bindings (group 0): 0 = `a`, a K × M row-major read-only matrix that is read transposed;
/// 1 = `b`, a K × N row-major read-only matrix; 2 = `c`, the M × N row-major output; 3 = the
/// `Dims` uniform. Note that the uniform's words are read as K, M, N, padding, which differs
/// from `MATMUL_SHADER`, so callers must pack K first.
///
/// Thread mapping matches `MATMUL_SHADER`: `row` comes from `global_invocation_id.x` and `col`
/// from `.y`, with a 16 × 16 workgroup.
const MATMUL_TRANSPOSE_SHADER: &str = r#"
struct Dims { K: u32, M: u32, N: u32, _pad: u32 }

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    let col = gid.y;
    if row >= dims.M || col >= dims.N { return; }

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
        sum = sum + a[k * dims.M + row] * b[k * dims.N + col];
    }
    c[row * dims.N + col] = sum;
}
"#;