1use std::collections::HashMap;
18use std::sync::{Arc, Mutex};
19
20pub struct ShaderCache {
25 cache: Mutex<HashMap<String, Arc<wgpu::ShaderModule>>>,
26}
27
28impl ShaderCache {
29 #[must_use]
31 pub fn new() -> Self {
32 Self { cache: Mutex::new(HashMap::new()) }
33 }
34
35 pub fn get_or_insert(
48 &self,
49 key: &str,
50 device: &wgpu::Device,
51 shader_source: &str,
52 ) -> Arc<wgpu::ShaderModule> {
53 let mut cache = self
54 .cache
55 .lock()
56 .expect("Shader cache mutex poisoned (should never happen in normal operation)");
57
58 if !cache.contains_key(key) {
59 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
60 label: Some(key),
61 source: wgpu::ShaderSource::Wgsl(shader_source.into()),
62 });
63 cache.insert(key.to_string(), Arc::new(shader));
64 }
65
66 Arc::clone(cache.get(key).expect("Shader must exist in cache after insertion"))
68 }
69
70 #[must_use]
75 pub fn stats(&self) -> (usize, usize) {
76 let cache = self
77 .cache
78 .lock()
79 .expect("Shader cache mutex poisoned (should never happen in normal operation)");
80 (cache.len(), cache.capacity())
81 }
82}
83
84impl Default for ShaderCache {
85 fn default() -> Self {
86 Self::new()
87 }
88}
89
90pub struct JitCompiler {
95 cache: ShaderCache,
96}
97
98impl JitCompiler {
99 #[must_use]
101 pub fn new() -> Self {
102 Self { cache: ShaderCache::new() }
103 }
104
105 #[must_use]
122 pub fn generate_fused_filter_sum(&self, filter_threshold: i32, filter_op: &str) -> String {
123 let wgsl_op = match filter_op {
125 "lt" => "<",
126 "eq" => "==",
127 "gte" => ">=",
128 "lte" => "<=",
129 "ne" => "!=",
130 _ => ">", };
132
133 format!(
134 r"
135@group(0) @binding(0) var<storage, read> input: array<i32>;
136@group(0) @binding(1) var<storage, read_write> output: array<atomic<i32>>;
137
138var<workgroup> shared_data: array<i32, 256>;
139
140@compute @workgroup_size(256)
141fn fused_filter_sum(@builtin(global_invocation_id) global_id: vec3<u32>,
142 @builtin(local_invocation_id) local_id: vec3<u32>) {{
143 let tid = local_id.x;
144 let gid = global_id.x;
145 let input_size = arrayLength(&input);
146
147 // Fused filter + load: Apply filter predicate during load
148 // Eliminates intermediate buffer write (Muda elimination)
149 var value: i32 = 0;
150 if (gid < input_size) {{
151 let data = input[gid];
152 // Filter: WHERE value {wgsl_op} {filter_threshold}
153 if (data {wgsl_op} {filter_threshold}) {{
154 value = data;
155 }}
156 }}
157 shared_data[tid] = value;
158 workgroupBarrier();
159
160 // Parallel reduction (same as unfused SUM kernel)
161 var stride = 128u;
162 while (stride > 0u) {{
163 if (tid < stride && gid + stride < input_size) {{
164 shared_data[tid] += shared_data[tid + stride];
165 }}
166 workgroupBarrier();
167 stride = stride / 2u;
168 }}
169
170 // Write result
171 if (tid == 0u) {{
172 atomicAdd(&output[0], shared_data[0]);
173 }}
174}}
175"
176 )
177 }
178
179 pub fn compile_fused_filter_sum(
189 &self,
190 device: &wgpu::Device,
191 filter_threshold: i32,
192 filter_op: &str,
193 ) -> Arc<wgpu::ShaderModule> {
194 let cache_key = format!("filter_{filter_op}_{filter_threshold}_sum");
196
197 let shader_source = self.generate_fused_filter_sum(filter_threshold, filter_op);
199
200 self.cache.get_or_insert(&cache_key, device, &shader_source)
202 }
203
204 #[must_use]
206 pub fn cache_stats(&self) -> (usize, usize) {
207 self.cache.stats()
208 }
209}
210
211impl Default for JitCompiler {
212 fn default() -> Self {
213 Self::new()
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates the fused filter+sum kernel for `threshold`/`op` with a fresh compiler.
    fn kernel(threshold: i32, op: &str) -> String {
        JitCompiler::new().generate_fused_filter_sum(threshold, op)
    }

    /// Asserts that every entry of `fragments` occurs somewhere in `source`.
    fn assert_contains_all(source: &str, fragments: &[&str]) {
        for &fragment in fragments {
            assert!(source.contains(fragment));
        }
    }

    #[test]
    fn test_shader_cache_new() {
        let (len, _) = ShaderCache::new().stats();
        assert_eq!(len, 0);
    }

    #[test]
    fn test_jit_compiler_new() {
        let (len, _) = JitCompiler::new().cache_stats();
        assert_eq!(len, 0);
    }

    #[test]
    fn test_generate_fused_filter_sum() {
        assert_contains_all(&kernel(1000, "gt"), &["if (data > 1000)", "fused_filter_sum"]);
        assert_contains_all(&kernel(500, "lt"), &["if (data < 500)"]);
        assert_contains_all(&kernel(42, "eq"), &["if (data == 42)"]);
    }

    #[test]
    fn test_shader_source_contains_fusion() {
        assert_contains_all(
            &kernel(100, "gte"),
            &[
                "@workgroup_size(256)",
                "var<workgroup> shared_data",
                "atomicAdd",
                "workgroupBarrier",
                "if (data >= 100)",
            ],
        );
    }

    #[test]
    fn test_all_filter_operators() {
        let cases = [
            (10, "gte", "if (data >= 10)"),
            (20, "lte", "if (data <= 20)"),
            (30, "ne", "if (data != 30)"),
            // Unrecognised operators fall back to `>`.
            (40, "unknown", "if (data > 40)"),
        ];
        for (threshold, op, expected) in cases {
            assert_contains_all(&kernel(threshold, op), &[expected]);
        }
    }

    #[test]
    fn test_shader_cache_default() {
        let (len, _) = ShaderCache::default().stats();
        assert_eq!(len, 0);
    }

    #[test]
    fn test_jit_compiler_default() {
        let (len, _) = JitCompiler::default().cache_stats();
        assert_eq!(len, 0);
    }

    /// Different thresholds or operators must yield different kernel sources.
    #[test]
    fn test_distinct_parameters_produce_distinct_sources() {
        let base = kernel(100, "gt");
        assert_ne!(base, kernel(200, "gt"));
        assert_ne!(base, kernel(100, "lt"));
    }

    #[test]
    fn test_wgsl_syntax_valid() {
        assert_contains_all(
            &kernel(999, "eq"),
            &[
                "@group(0) @binding(0)",
                "@group(0) @binding(1)",
                "@compute @workgroup_size(256)",
                "@builtin(global_invocation_id)",
                "@builtin(local_invocation_id)",
                "var<workgroup>",
                "var<storage, read>",
                "var<storage, read_write>",
                "array<atomic<i32>>",
            ],
        );
    }

    #[test]
    fn test_parallel_reduction_logic() {
        assert_contains_all(
            &kernel(500, "gt"),
            &[
                "var stride = 128u;",
                "while (stride > 0u)",
                "stride = stride / 2u;",
                "if (tid == 0u)",
            ],
        );
    }

    #[test]
    fn test_muda_elimination_comment() {
        assert_contains_all(&kernel(100, "gt"), &["Eliminates intermediate buffer write"]);
    }
}