1use super::context::{
12 ComputeBindingDescriptor, ComputeBindingKind, ComputeDispatchDescriptor, GpuContext,
13};
14
/// Minimum number of single excitations (`n_singles`) before the GPU path is
/// taken; below this the dispatch/transfer overhead is assumed to outweigh
/// the compute win and the f64 CPU fallback is used instead.
const GPU_DISPATCH_THRESHOLD: usize = 100;
17
18pub fn compute_stda_j_matrix_gpu(
27 ctx: &GpuContext,
28 q_matrix: &[f64],
29 gamma: &[f64],
30 n_atoms: usize,
31 n_singles: usize,
32) -> Result<Vec<f64>, String> {
33 if n_singles < GPU_DISPATCH_THRESHOLD || !ctx.capabilities.gpu_available {
34 return compute_stda_j_matrix_cpu(q_matrix, gamma, n_atoms, n_singles);
35 }
36
37 let q_f32: Vec<f32> = q_matrix.iter().map(|&x| x as f32).collect();
44 let gamma_f32: Vec<f32> = gamma.iter().map(|&x| x as f32).collect();
45
46 let gamma_q_bytes = vec![0u8; n_atoms * n_singles * 4];
48
49 let dispatch = ComputeDispatchDescriptor {
50 label: "stda_gamma_q".to_string(),
51 shader_source: MATMUL_SHADER.to_string(),
52 entry_point: "main".to_string(),
53 workgroup_count: [
54 n_atoms.div_ceil(16) as u32,
55 n_singles.div_ceil(16) as u32,
56 1,
57 ],
58 bindings: vec![
59 ComputeBindingDescriptor {
60 label: "gamma".to_string(),
61 kind: ComputeBindingKind::StorageReadOnly,
62 bytes: bytemuck_cast_f32(&gamma_f32),
63 },
64 ComputeBindingDescriptor {
65 label: "q".to_string(),
66 kind: ComputeBindingKind::StorageReadOnly,
67 bytes: bytemuck_cast_f32(&q_f32),
68 },
69 ComputeBindingDescriptor {
70 label: "result".to_string(),
71 kind: ComputeBindingKind::StorageReadWrite,
72 bytes: gamma_q_bytes,
73 },
74 ComputeBindingDescriptor {
75 label: "dims".to_string(),
76 kind: ComputeBindingKind::Uniform,
77 bytes: pack_dims(n_atoms as u32, n_atoms as u32, n_singles as u32),
78 },
79 ],
80 };
81
82 let gamma_q_result = ctx
83 .run_compute(&dispatch)?
84 .outputs
85 .into_iter()
86 .last()
87 .unwrap_or_default();
88
89 let result_bytes = vec![0u8; n_singles * n_singles * 4];
91
92 let dispatch2 = ComputeDispatchDescriptor {
93 label: "stda_qt_gamma_q".to_string(),
94 shader_source: MATMUL_TRANSPOSE_SHADER.to_string(),
95 entry_point: "main".to_string(),
96 workgroup_count: [
97 n_singles.div_ceil(16) as u32,
98 n_singles.div_ceil(16) as u32,
99 1,
100 ],
101 bindings: vec![
102 ComputeBindingDescriptor {
103 label: "q".to_string(),
104 kind: ComputeBindingKind::StorageReadOnly,
105 bytes: bytemuck_cast_f32(&q_f32),
106 },
107 ComputeBindingDescriptor {
108 label: "gamma_q".to_string(),
109 kind: ComputeBindingKind::StorageReadOnly,
110 bytes: gamma_q_result,
111 },
112 ComputeBindingDescriptor {
113 label: "result".to_string(),
114 kind: ComputeBindingKind::StorageReadWrite,
115 bytes: result_bytes,
116 },
117 ComputeBindingDescriptor {
118 label: "dims".to_string(),
119 kind: ComputeBindingKind::Uniform,
120 bytes: pack_dims(n_atoms as u32, n_singles as u32, n_singles as u32),
121 },
122 ],
123 };
124
125 let a_off_bytes = ctx
126 .run_compute(&dispatch2)?
127 .outputs
128 .into_iter()
129 .last()
130 .unwrap_or_default();
131
132 let a_off_f32: &[f32] = bytemuck_cast_from_u8(&a_off_bytes);
134 Ok(a_off_f32.iter().map(|&x| 2.0 * x as f64).collect())
135}
136
/// CPU reference path for the sTDA Coulomb coupling matrix
/// `A_J = 2 * qᵀ Γ q`, accumulated entirely in f64.
///
/// `q_matrix` is row-major `n_atoms × n_singles`, `gamma` row-major
/// `n_atoms × n_atoms`; the returned matrix is row-major
/// `n_singles × n_singles`. Only the lower triangle is computed and each
/// value is mirrored into the upper triangle, since the result is symmetric.
fn compute_stda_j_matrix_cpu(
    q_matrix: &[f64],
    gamma: &[f64],
    n_atoms: usize,
    n_singles: usize,
) -> Result<Vec<f64>, String> {
    let mut j_matrix = vec![0.0; n_singles * n_singles];

    for row in 0..n_singles {
        for col in 0..=row {
            let mut acc = 0.0;
            for atom_a in 0..n_atoms {
                let q_left = q_matrix[atom_a * n_singles + row];
                // Negligible transition charges contribute nothing; skipping
                // them prunes the whole inner loop over atom_b.
                if q_left.abs() < 1e-12 {
                    continue;
                }
                let gamma_row = &gamma[atom_a * n_atoms..(atom_a + 1) * n_atoms];
                for (atom_b, &g_ab) in gamma_row.iter().enumerate() {
                    acc += q_left * g_ab * q_matrix[atom_b * n_singles + col];
                }
            }
            let value = 2.0 * acc;
            j_matrix[row * n_singles + col] = value;
            j_matrix[col * n_singles + row] = value;
        }
    }

    Ok(j_matrix)
}
166
/// Serializes an `f32` slice into native-endian bytes for GPU upload.
fn bytemuck_cast_f32(data: &[f32]) -> Vec<u8> {
    let mut bytes = Vec::with_capacity(data.len() * 4);
    for value in data {
        bytes.extend_from_slice(&value.to_ne_bytes());
    }
    bytes
}
170
/// Reinterprets a byte buffer as native-endian `f32`s without copying.
///
/// Returns an empty slice when the buffer start is not 4-byte aligned or its
/// length is not a multiple of 4; callers treat that as "no data".
fn bytemuck_cast_from_u8(data: &[u8]) -> &[f32] {
    // SAFETY: every 4-byte pattern is a valid `f32` bit pattern, and
    // `align_to` guarantees the middle slice is correctly aligned and sized
    // for `f32`.
    match unsafe { data.align_to::<f32>() } {
        ([], floats, []) => floats,
        _ => &[],
    }
}
180
/// Packs the matmul dimensions `(m, k, n)` plus a zero padding word into the
/// 16-byte native-endian layout expected by the shaders' `Dims` uniform.
fn pack_dims(m: u32, k: u32, n: u32) -> Vec<u8> {
    [m, k, n, 0]
        .iter()
        .flat_map(|word| word.to_ne_bytes())
        .collect()
}
189
/// WGSL kernel computing `C = A · B` for row-major matrices with
/// `A: M×K`, `B: K×N`, `C: M×N`, one output element per invocation in
/// 16×16 workgroup tiles (matching the `div_ceil(16)` dispatch counts).
/// Bound here as `a = gamma`, `b = q`, so it produces the intermediate
/// `Γ·q` (n_atoms × n_singles).
const MATMUL_SHADER: &str = r#"
struct Dims { M: u32, K: u32, N: u32, _pad: u32 }

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    let col = gid.y;
    if row >= dims.M || col >= dims.N { return; }

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
        sum = sum + a[row * dims.K + k] * b[k * dims.N + col];
    }
    c[row * dims.N + col] = sum;
}
"#;
212
/// WGSL kernel computing `C = Aᵀ · B` with `A: K×M` (read transposed via
/// `a[k * M + row]`), `B: K×N`, `C: M×N`. Note the `Dims` field order is
/// `K, M, N`, so the first word written by `pack_dims` lands in `K` — the
/// second dispatch packs `(n_atoms, n_singles, n_singles)` accordingly.
/// Bound here as `a = q`, `b = gamma_q`, producing `qᵀ·(Γ·q)`.
const MATMUL_TRANSPOSE_SHADER: &str = r#"
struct Dims { K: u32, M: u32, N: u32, _pad: u32 }

@group(0) @binding(0) var<storage, read> a: array<f32>;
@group(0) @binding(1) var<storage, read> b: array<f32>;
@group(0) @binding(2) var<storage, read_write> c: array<f32>;
@group(0) @binding(3) var<uniform> dims: Dims;

@compute @workgroup_size(16, 16)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let row = gid.x;
    let col = gid.y;
    if row >= dims.M || col >= dims.N { return; }

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < dims.K; k = k + 1u) {
        sum = sum + a[k * dims.M + row] * b[k * dims.N + col];
    }
    c[row * dims.N + col] = sum;
}
"#;