// trueno_db/gpu/mod.rs

1//! GPU compute backend using wgpu (WebGPU)
2//!
3//! Toyota Way Principles:
4//! - Muda elimination: GPU only when compute > 5x transfer time
5//! - Genchi Genbutsu: Empirical benchmarks prove 50-100x speedups
6//!
7//! Architecture:
8//! - WGSL compute shaders for parallel reduction
9//! - Workgroup size: 256 threads (GPU warp size optimization)
10//! - Two-stage reduction: workgroup-local + global
11//!
12//! References:
13//! - `HeavyDB` (2017): GPU aggregation patterns
14//! - Harris (2007): Optimizing parallel reduction in CUDA
15//! - Leis et al. (2014): Morsel-driven parallelism
16
17use crate::{Error, Result};
18use arrow::array::{Array, Float32Array, Int32Array};
19use wgpu;
20use wgpu::util::DeviceExt;
21
22pub mod jit;
23pub mod kernels;
24pub mod multigpu;
25
/// GPU compute engine for aggregations
///
/// Owns the wgpu device/queue pair used by all aggregation kernels, plus a
/// JIT compiler used to build fused kernels (see `fused_filter_sum`).
pub struct GpuEngine {
    /// GPU device handle (public for benchmarking)
    pub device: wgpu::Device,
    /// GPU command queue (public for benchmarking)
    pub queue: wgpu::Queue,
    /// JIT compiler for kernel fusion
    /// (presumably caches compiled shader modules per the note in
    /// `fused_filter_sum` — confirm against `jit::JitCompiler`)
    jit: jit::JitCompiler,
}
35
36impl GpuEngine {
37    /// Initialize GPU engine
38    ///
39    /// # Errors
40    /// Returns error if GPU initialization fails (no GPU available, driver issues, etc.)
41    pub async fn new() -> Result<Self> {
42        // Request GPU adapter
43        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
44            backends: wgpu::Backends::all(),
45            ..Default::default()
46        });
47
48        let adapter = instance
49            .request_adapter(&wgpu::RequestAdapterOptions {
50                power_preference: wgpu::PowerPreference::HighPerformance,
51                compatible_surface: None,
52                force_fallback_adapter: false,
53            })
54            .await
55            .ok_or_else(|| Error::GpuInitFailed("No GPU adapter found".to_string()))?;
56
57        // Request device and queue
58        let (device, queue) = adapter
59            .request_device(
60                &wgpu::DeviceDescriptor {
61                    label: Some("Trueno-DB GPU Device"),
62                    required_features: wgpu::Features::empty(),
63                    required_limits: wgpu::Limits::default(),
64                    memory_hints: wgpu::MemoryHints::default(),
65                },
66                None,
67            )
68            .await
69            .map_err(|e| Error::GpuInitFailed(format!("Failed to create device: {e}")))?;
70
71        Ok(Self { device, queue, jit: jit::JitCompiler::new() })
72    }
73
74    /// Execute SUM aggregation on GPU
75    ///
76    /// # Arguments
77    /// * `data` - Input array (Int32 or Float32)
78    ///
79    /// # Returns
80    /// Sum of all elements
81    ///
82    /// # Errors
83    /// Returns error if GPU execution fails
84    pub async fn sum_i32(&self, data: &Int32Array) -> Result<i32> {
85        kernels::sum_i32(&self.device, &self.queue, data).await
86    }
87
88    /// Execute SUM aggregation on GPU (f32)
89    ///
90    /// # Errors
91    /// Returns error if GPU execution fails
92    pub async fn sum_f32(&self, data: &Float32Array) -> Result<f32> {
93        kernels::sum_f32(&self.device, &self.queue, data).await
94    }
95
96    /// Execute COUNT aggregation on GPU
97    ///
98    /// # Errors
99    /// Returns error if GPU execution fails
100    pub async fn count(&self, data: &dyn Array) -> Result<usize> {
101        kernels::count(&self.device, &self.queue, data).await
102    }
103
104    /// Execute MIN aggregation on GPU
105    ///
106    /// # Errors
107    /// Returns error if GPU execution fails
108    pub async fn min_i32(&self, data: &Int32Array) -> Result<i32> {
109        kernels::min_i32(&self.device, &self.queue, data).await
110    }
111
112    /// Execute MAX aggregation on GPU
113    ///
114    /// # Errors
115    /// Returns error if GPU execution fails
116    pub async fn max_i32(&self, data: &Int32Array) -> Result<i32> {
117        kernels::max_i32(&self.device, &self.queue, data).await
118    }
119
120    /// Execute AVG aggregation on GPU (reuses sum + count)
121    ///
122    /// # Errors
123    /// Returns error if GPU execution fails
124    #[allow(clippy::cast_precision_loss)]
125    pub async fn avg_f32(&self, data: &Float32Array) -> Result<f32> {
126        let sum = self.sum_f32(data).await?;
127        let count = self.count(data).await?;
128        if count == 0 {
129            Ok(0.0)
130        } else {
131            Ok(sum / count as f32)
132        }
133    }
134
135    /// Execute fused filter+sum aggregation on GPU (JIT-compiled kernel)
136    ///
137    /// Toyota Way: Muda elimination - fuses filter and sum in single pass,
138    /// eliminating intermediate buffer write.
139    ///
140    /// # Arguments
141    /// * `data` - Input array (Int32)
142    /// * `filter_threshold` - Filter threshold value (e.g., WHERE value > 1000)
143    /// * `filter_op` - Filter operator ("gt", "lt", "eq", "gte", "lte", "ne")
144    ///
145    /// # Returns
146    /// Sum of filtered elements
147    ///
148    /// # Errors
149    /// Returns error if GPU execution fails
150    ///
151    /// # Example
152    /// ```ignore
153    /// // Equivalent to: SELECT SUM(value) FROM data WHERE value > 1000
154    /// let result = engine.fused_filter_sum(&data, 1000, "gt").await?;
155    /// ```
156    #[allow(clippy::too_many_lines)]
157    #[allow(clippy::cast_possible_truncation)]
158    pub async fn fused_filter_sum(
159        &self,
160        data: &Int32Array,
161        filter_threshold: i32,
162        filter_op: &str,
163    ) -> Result<i32> {
164        // JIT compile the fused kernel (cached automatically)
165        let shader_module =
166            self.jit.compile_fused_filter_sum(&self.device, filter_threshold, filter_op);
167
168        // Prepare input data
169        let input_data: Vec<i32> = data.values().to_vec();
170        let input_size = input_data.len();
171
172        if input_size == 0 {
173            return Ok(0);
174        }
175
176        // Create GPU buffers
177        let input_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
178            label: Some("Fused Filter+Sum Input"),
179            contents: bytemuck::cast_slice(&input_data),
180            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
181        });
182
183        let output_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
184            label: Some("Fused Filter+Sum Output"),
185            contents: bytemuck::cast_slice(&[0i32]),
186            usage: wgpu::BufferUsages::STORAGE
187                | wgpu::BufferUsages::COPY_SRC
188                | wgpu::BufferUsages::COPY_DST,
189        });
190
191        // Create bind group layout
192        let bind_group_layout =
193            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
194                label: Some("Fused Filter+Sum Bind Group Layout"),
195                entries: &[
196                    wgpu::BindGroupLayoutEntry {
197                        binding: 0,
198                        visibility: wgpu::ShaderStages::COMPUTE,
199                        ty: wgpu::BindingType::Buffer {
200                            ty: wgpu::BufferBindingType::Storage { read_only: true },
201                            has_dynamic_offset: false,
202                            min_binding_size: None,
203                        },
204                        count: None,
205                    },
206                    wgpu::BindGroupLayoutEntry {
207                        binding: 1,
208                        visibility: wgpu::ShaderStages::COMPUTE,
209                        ty: wgpu::BindingType::Buffer {
210                            ty: wgpu::BufferBindingType::Storage { read_only: false },
211                            has_dynamic_offset: false,
212                            min_binding_size: None,
213                        },
214                        count: None,
215                    },
216                ],
217            });
218
219        // Create bind group
220        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
221            label: Some("Fused Filter+Sum Bind Group"),
222            layout: &bind_group_layout,
223            entries: &[
224                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
225                wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
226            ],
227        });
228
229        // Create compute pipeline
230        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
231            label: Some("Fused Filter+Sum Pipeline Layout"),
232            bind_group_layouts: &[&bind_group_layout],
233            push_constant_ranges: &[],
234        });
235
236        let compute_pipeline =
237            self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
238                label: Some("Fused Filter+Sum Pipeline"),
239                layout: Some(&pipeline_layout),
240                module: &shader_module,
241                entry_point: "fused_filter_sum",
242                compilation_options: wgpu::PipelineCompilationOptions::default(),
243                cache: None,
244            });
245
246        // Create command encoder and execute
247        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
248            label: Some("Fused Filter+Sum Encoder"),
249        });
250
251        {
252            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
253                label: Some("Fused Filter+Sum Pass"),
254                timestamp_writes: None,
255            });
256            compute_pass.set_pipeline(&compute_pipeline);
257            compute_pass.set_bind_group(0, &bind_group, &[]);
258
259            // Dispatch workgroups (256 threads per workgroup)
260            let workgroup_count = (input_size as u32).div_ceil(256);
261            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
262        }
263
264        // Copy output to staging buffer
265        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
266            label: Some("Fused Filter+Sum Staging Buffer"),
267            size: 4,
268            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
269            mapped_at_creation: false,
270        });
271
272        encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, 4);
273
274        // Submit commands
275        self.queue.submit(Some(encoder.finish()));
276
277        // Read result
278        let buffer_slice = staging_buffer.slice(..);
279        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
280        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
281            tx.send(result).ok();
282        });
283
284        self.device.poll(wgpu::Maintain::Wait);
285
286        rx.receive()
287            .await
288            .ok_or_else(|| Error::Other("Failed to receive buffer map result".to_string()))?
289            .map_err(|e| Error::Other(format!("Buffer mapping failed: {e}")))?;
290
291        let data_view = buffer_slice.get_mapped_range();
292        let result = i32::from_le_bytes([data_view[0], data_view[1], data_view[2], data_view[3]]);
293
294        drop(data_view);
295        staging_buffer.unmap();
296
297        Ok(result)
298    }
299}
300
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    /// Try to bring up a [`GpuEngine`]; on failure, print the standard skip
    /// notice and return `None` so the caller can bail out early.
    async fn engine_or_skip() -> Option<GpuEngine> {
        match GpuEngine::new().await {
            Ok(engine) => Some(engine),
            Err(_) => {
                eprintln!("Skipping GPU test (no GPU available)");
                None
            }
        }
    }

    #[tokio::test]
    async fn test_gpu_init() {
        // This test may fail on machines without GPU; either outcome is fine.
        if let Err(e) = GpuEngine::new().await {
            eprintln!("GPU initialization failed (expected on CI): {e}");
        }
    }

    #[tokio::test]
    async fn test_gpu_sum_basic() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.sum_i32(&input).await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_gpu_sum_empty() {
        let Some(engine) = engine_or_skip().await else { return };
        // SUM over zero rows is defined as 0.
        let input = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.sum_i32(&input).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_gpu_min_i32() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Int32Array::from(vec![5, 2, 8, 1, 9]);
        assert_eq!(engine.min_i32(&input).await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_gpu_min_empty() {
        let Some(engine) = engine_or_skip().await else { return };
        // MIN's identity element is i32::MAX on empty input.
        let input = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.min_i32(&input).await.unwrap(), i32::MAX);
    }

    #[tokio::test]
    async fn test_gpu_max_i32() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Int32Array::from(vec![5, 2, 8, 1, 9]);
        assert_eq!(engine.max_i32(&input).await.unwrap(), 9);
    }

    #[tokio::test]
    async fn test_gpu_max_empty() {
        let Some(engine) = engine_or_skip().await else { return };
        // MAX's identity element is i32::MIN on empty input.
        let input = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.max_i32(&input).await.unwrap(), i32::MIN);
    }

    #[tokio::test]
    async fn test_gpu_count() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.count(&input).await.unwrap(), 5);
    }

    #[tokio::test]
    async fn test_gpu_sum_f32_not_implemented() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Float32Array::from(vec![1.0, 2.0, 3.0]);
        let outcome = engine.sum_f32(&input).await;
        assert!(outcome.is_err());
        assert!(outcome.unwrap_err().to_string().contains("not yet implemented"));
    }

    #[tokio::test]
    async fn test_gpu_avg_f32_not_implemented() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Float32Array::from(vec![2.0, 4.0, 6.0]);
        // avg_f32 calls sum_f32 which returns error
        assert!(engine.avg_f32(&input).await.is_err());
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_gt() {
        let Some(engine) = engine_or_skip().await else { return };
        // value > 5 over 1..=10 keeps 6..=10, summing to 40.
        let input = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(engine.fused_filter_sum(&input, 5, "gt").await.unwrap(), 40);
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_lt() {
        let Some(engine) = engine_or_skip().await else { return };
        // value < 4 over [1..=5] keeps 1, 2, 3, summing to 6.
        let input = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.fused_filter_sum(&input, 4, "lt").await.unwrap(), 6);
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_eq() {
        let Some(engine) = engine_or_skip().await else { return };
        // value == 5 keeps the three fives, summing to 15.
        let input = Int32Array::from(vec![1, 5, 5, 3, 5]);
        assert_eq!(engine.fused_filter_sum(&input, 5, "eq").await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_empty() {
        let Some(engine) = engine_or_skip().await else { return };
        let input = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.fused_filter_sum(&input, 5, "gt").await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_no_matches() {
        let Some(engine) = engine_or_skip().await else { return };
        // All values < 100, so the filter passes nothing and the sum is 0.
        let input = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.fused_filter_sum(&input, 100, "gt").await.unwrap(), 0);
    }
}