// trueno_db/gpu/jit.rs
//! JIT WGSL Compiler for Kernel Fusion (CORE-003)
//!
//! Toyota Way: Muda elimination (waste of intermediate memory writes)
//!
//! This module provides runtime compilation of fused WGSL kernels.
//! Phase 1 MVP: Simplified approach with template-based code generation.
//!
//! Example: Filter + SUM fusion
//! - Non-fused: Filter → intermediate buffer → SUM (2 GPU passes, 1 memory write)
//! - Fused: Filter + SUM in single pass (1 GPU pass, 0 intermediate writes)
//!
//! References:
//! - Wu et al. (2012): Kernel fusion execution model
//! - Neumann (2011): JIT compilation for queries
//! - MonetDB/X100 (2005): Vectorized query execution

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Shader compilation cache for JIT-compiled kernels
///
/// Caches compiled shaders by query signature to avoid recompilation.
/// Thread-safe via Mutex for concurrent query execution.
pub struct ShaderCache {
    // Query signature -> compiled module. The Arc lets callers keep using a
    // shader after the lock is released; the Mutex serializes cache access.
    cache: Mutex<HashMap<String, Arc<wgpu::ShaderModule>>>,
}

28impl ShaderCache {
29    /// Create a new shader cache
30    #[must_use]
31    pub fn new() -> Self {
32        Self { cache: Mutex::new(HashMap::new()) }
33    }
34
35    /// Get cached shader or insert new one
36    ///
37    /// # Arguments
38    /// * `key` - Query signature (e.g., `"filter_gt_1000_sum"`)
39    /// * `device` - GPU device for shader compilation
40    /// * `shader_source` - WGSL shader source code
41    ///
42    /// # Returns
43    /// Arc reference to compiled shader module (either cached or newly compiled)
44    ///
45    /// # Panics
46    /// Panics if the cache mutex is poisoned (should never happen in normal operation)
47    pub fn get_or_insert(
48        &self,
49        key: &str,
50        device: &wgpu::Device,
51        shader_source: &str,
52    ) -> Arc<wgpu::ShaderModule> {
53        let mut cache = self
54            .cache
55            .lock()
56            .expect("Shader cache mutex poisoned (should never happen in normal operation)");
57
58        if !cache.contains_key(key) {
59            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
60                label: Some(key),
61                source: wgpu::ShaderSource::Wgsl(shader_source.into()),
62            });
63            cache.insert(key.to_string(), Arc::new(shader));
64        }
65
66        // Clone the Arc (cheap), not the ShaderModule
67        Arc::clone(cache.get(key).expect("Shader must exist in cache after insertion"))
68    }
69
70    /// Get cache statistics
71    ///
72    /// # Panics
73    /// Panics if the cache mutex is poisoned (should never happen in normal operation)
74    #[must_use]
75    pub fn stats(&self) -> (usize, usize) {
76        let cache = self
77            .cache
78            .lock()
79            .expect("Shader cache mutex poisoned (should never happen in normal operation)");
80        (cache.len(), cache.capacity())
81    }
82}
83
84impl Default for ShaderCache {
85    fn default() -> Self {
86        Self::new()
87    }
88}
89
/// JIT WGSL compiler for kernel fusion
///
/// Phase 1 MVP: Template-based code generation for common patterns.
/// Future: Full SQL AST → WGSL compilation in Phase 2.
pub struct JitCompiler {
    // Compiled-shader cache keyed by query signature; see ShaderCache.
    cache: ShaderCache,
}

98impl JitCompiler {
99    /// Create a new JIT compiler with shader cache
100    #[must_use]
101    pub fn new() -> Self {
102        Self { cache: ShaderCache::new() }
103    }
104
105    /// Generate fused filter+sum kernel
106    ///
107    /// Fuses WHERE clause with SUM aggregation in single GPU pass.
108    ///
109    /// # Arguments
110    /// * `filter_threshold` - Threshold value for filter (e.g., WHERE value > 1000)
111    /// * `filter_op` - Filter operator ("gt", "lt", "eq", "gte", "lte")
112    ///
113    /// # Returns
114    /// WGSL shader source code for fused kernel
115    ///
116    /// # Example
117    /// ```ignore
118    /// let shader = compiler.generate_fused_filter_sum(1000, "gt");
119    /// // Generates: WHERE value > 1000, SUM(value) in single pass
120    /// ```
121    #[must_use]
122    pub fn generate_fused_filter_sum(&self, filter_threshold: i32, filter_op: &str) -> String {
123        // Convert operator to WGSL
124        let wgsl_op = match filter_op {
125            "lt" => "<",
126            "eq" => "==",
127            "gte" => ">=",
128            "lte" => "<=",
129            "ne" => "!=",
130            _ => ">", // Default to greater-than (handles "gt" and unknown ops)
131        };
132
133        format!(
134            r"
135@group(0) @binding(0) var<storage, read> input: array<i32>;
136@group(0) @binding(1) var<storage, read_write> output: array<atomic<i32>>;
137
138var<workgroup> shared_data: array<i32, 256>;
139
140@compute @workgroup_size(256)
141fn fused_filter_sum(@builtin(global_invocation_id) global_id: vec3<u32>,
142                    @builtin(local_invocation_id) local_id: vec3<u32>) {{
143    let tid = local_id.x;
144    let gid = global_id.x;
145    let input_size = arrayLength(&input);
146
147    // Fused filter + load: Apply filter predicate during load
148    // Eliminates intermediate buffer write (Muda elimination)
149    var value: i32 = 0;
150    if (gid < input_size) {{
151        let data = input[gid];
152        // Filter: WHERE value {wgsl_op} {filter_threshold}
153        if (data {wgsl_op} {filter_threshold}) {{
154            value = data;
155        }}
156    }}
157    shared_data[tid] = value;
158    workgroupBarrier();
159
160    // Parallel reduction (same as unfused SUM kernel)
161    var stride = 128u;
162    while (stride > 0u) {{
163        if (tid < stride && gid + stride < input_size) {{
164            shared_data[tid] += shared_data[tid + stride];
165        }}
166        workgroupBarrier();
167        stride = stride / 2u;
168    }}
169
170    // Write result
171    if (tid == 0u) {{
172        atomicAdd(&output[0], shared_data[0]);
173    }}
174}}
175"
176        )
177    }
178
179    /// Compile and cache fused filter+sum kernel
180    ///
181    /// # Arguments
182    /// * `device` - GPU device for compilation
183    /// * `filter_threshold` - Filter threshold value
184    /// * `filter_op` - Filter operator
185    ///
186    /// # Returns
187    /// Arc reference to compiled shader module (cached for reuse)
188    pub fn compile_fused_filter_sum(
189        &self,
190        device: &wgpu::Device,
191        filter_threshold: i32,
192        filter_op: &str,
193    ) -> Arc<wgpu::ShaderModule> {
194        // Generate cache key from query signature
195        let cache_key = format!("filter_{filter_op}_{filter_threshold}_sum");
196
197        // Generate WGSL source
198        let shader_source = self.generate_fused_filter_sum(filter_threshold, filter_op);
199
200        // Get from cache or compile
201        self.cache.get_or_insert(&cache_key, device, &shader_source)
202    }
203
204    /// Get cache statistics (size, capacity)
205    #[must_use]
206    pub fn cache_stats(&self) -> (usize, usize) {
207        self.cache.stats()
208    }
209}
210
211impl Default for JitCompiler {
212    fn default() -> Self {
213        Self::new()
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    // All tests below exercise shader *generation* and cache bookkeeping only;
    // no wgpu device is created, so they run on CI machines without a GPU.

    // A freshly created cache is empty.
    #[test]
    fn test_shader_cache_new() {
        let cache = ShaderCache::new();
        let (size, _capacity) = cache.stats();
        assert_eq!(size, 0);
    }

    // A freshly created compiler starts with an empty shader cache.
    #[test]
    fn test_jit_compiler_new() {
        let compiler = JitCompiler::new();
        let (size, _capacity) = compiler.cache_stats();
        assert_eq!(size, 0);
    }

    // Threshold and operator are interpolated into the WGSL predicate.
    #[test]
    fn test_generate_fused_filter_sum() {
        let compiler = JitCompiler::new();

        // Test greater-than filter
        let greater_than = compiler.generate_fused_filter_sum(1000, "gt");
        assert!(greater_than.contains("if (data > 1000)"));
        assert!(greater_than.contains("fused_filter_sum"));

        // Test less-than filter
        let less_than = compiler.generate_fused_filter_sum(500, "lt");
        assert!(less_than.contains("if (data < 500)"));

        // Test equals filter
        let equals = compiler.generate_fused_filter_sum(42, "eq");
        assert!(equals.contains("if (data == 42)"));
    }

    // The generated kernel has both the filter and the reduction in one pass.
    #[test]
    fn test_shader_source_contains_fusion() {
        let compiler = JitCompiler::new();
        let shader = compiler.generate_fused_filter_sum(100, "gte");

        // Verify it contains key fusion components
        assert!(shader.contains("@workgroup_size(256)"));
        assert!(shader.contains("var<workgroup> shared_data"));
        assert!(shader.contains("atomicAdd"));
        assert!(shader.contains("workgroupBarrier"));

        // Verify filter is inline (fused)
        assert!(shader.contains("if (data >= 100)"));
    }

    // Every supported operator maps to the expected WGSL token; an
    // unrecognized operator deliberately falls back to ">".
    #[test]
    fn test_all_filter_operators() {
        let compiler = JitCompiler::new();

        // Test all supported operators
        let gte_shader = compiler.generate_fused_filter_sum(10, "gte");
        assert!(gte_shader.contains("if (data >= 10)"));

        let lte_shader = compiler.generate_fused_filter_sum(20, "lte");
        assert!(lte_shader.contains("if (data <= 20)"));

        let ne_shader = compiler.generate_fused_filter_sum(30, "ne");
        assert!(ne_shader.contains("if (data != 30)"));

        // Test unknown operator defaults to >
        let unknown_shader = compiler.generate_fused_filter_sum(40, "unknown");
        assert!(unknown_shader.contains("if (data > 40)"));
    }

    // Default must match new(): empty cache.
    #[test]
    fn test_shader_cache_default() {
        let cache = ShaderCache::default();
        let (size, _capacity) = cache.stats();
        assert_eq!(size, 0);
    }

    // Default must match new(): empty cache.
    #[test]
    fn test_jit_compiler_default() {
        let compiler = JitCompiler::default();
        let (size, _capacity) = compiler.cache_stats();
        assert_eq!(size, 0);
    }

    // Distinct query signatures must yield distinct shader sources, otherwise
    // the signature-keyed cache would serve the wrong kernel.
    #[test]
    fn test_cache_key_generation() {
        let compiler = JitCompiler::new();

        // Different thresholds should generate different shaders
        let shader1 = compiler.generate_fused_filter_sum(100, "gt");
        let shader2 = compiler.generate_fused_filter_sum(200, "gt");
        assert_ne!(shader1, shader2);

        // Different operators should generate different shaders
        let shader3 = compiler.generate_fused_filter_sum(100, "lt");
        assert_ne!(shader1, shader3);
    }

    // Spot-check the WGSL attribute/binding syntax the template must emit.
    #[test]
    fn test_wgsl_syntax_valid() {
        let compiler = JitCompiler::new();
        let shader = compiler.generate_fused_filter_sum(999, "eq");

        // Verify critical WGSL syntax elements
        assert!(shader.contains("@group(0) @binding(0)"));
        assert!(shader.contains("@group(0) @binding(1)"));
        assert!(shader.contains("@compute @workgroup_size(256)"));
        assert!(shader.contains("@builtin(global_invocation_id)"));
        assert!(shader.contains("@builtin(local_invocation_id)"));
        assert!(shader.contains("var<workgroup>"));
        assert!(shader.contains("var<storage, read>"));
        assert!(shader.contains("var<storage, read_write>"));
        assert!(shader.contains("array<atomic<i32>>"));
    }

    // The tree reduction halves the stride each step, starting at 128 for a
    // 256-lane workgroup, and lane 0 publishes the workgroup's partial sum.
    #[test]
    fn test_parallel_reduction_logic() {
        let compiler = JitCompiler::new();
        let shader = compiler.generate_fused_filter_sum(500, "gt");

        // Verify parallel reduction pattern (Harris 2007)
        assert!(shader.contains("var stride = 128u;"));
        assert!(shader.contains("while (stride > 0u)"));
        assert!(shader.contains("stride = stride / 2u;"));
        assert!(shader.contains("if (tid == 0u)"));
    }

    // The rationale comment is part of the emitted source and is asserted on
    // so the fusion intent stays documented in generated kernels.
    #[test]
    fn test_muda_elimination_comment() {
        let compiler = JitCompiler::new();
        let shader = compiler.generate_fused_filter_sum(100, "gt");

        // Verify Toyota Way: Muda elimination comment exists
        assert!(shader.contains("Eliminates intermediate buffer write"));
    }
}