1use crate::{Error, Result};
18use arrow::array::{Array, Float32Array, Int32Array};
19use wgpu;
20use wgpu::util::DeviceExt;
21
22pub mod jit;
23pub mod kernels;
24pub mod multigpu;
25
/// GPU execution engine: owns the wgpu device/queue pair plus a JIT shader
/// compiler used to specialize fused kernels (see [`fused_filter_sum`]).
pub struct GpuEngine {
    /// Logical GPU device; used to create buffers, pipelines, and encoders.
    pub device: wgpu::Device,
    /// Command queue for submitting encoded GPU work to `device`.
    pub queue: wgpu::Queue,
    // JIT compiler that generates query-specialized WGSL shader modules.
    jit: jit::JitCompiler,
}
35
impl GpuEngine {
    /// Initializes the GPU engine.
    ///
    /// Creates a wgpu instance over all available backends, requests a
    /// high-performance adapter (compute only — no surface required), and
    /// opens a device/queue pair with default features and limits.
    ///
    /// # Errors
    ///
    /// Returns [`Error::GpuInitFailed`] when no adapter is found or when
    /// device creation fails.
    pub async fn new() -> Result<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| Error::GpuInitFailed("No GPU adapter found".to_string()))?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Trueno-DB GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                    memory_hints: wgpu::MemoryHints::default(),
                },
                None,
            )
            .await
            .map_err(|e| Error::GpuInitFailed(format!("Failed to create device: {e}")))?;

        Ok(Self { device, queue, jit: jit::JitCompiler::new() })
    }

    /// Sums an `Int32Array` on the GPU. Delegates to [`kernels::sum_i32`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the kernel implementation.
    pub async fn sum_i32(&self, data: &Int32Array) -> Result<i32> {
        kernels::sum_i32(&self.device, &self.queue, data).await
    }

    /// Sums a `Float32Array` on the GPU. Delegates to [`kernels::sum_f32`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the kernel implementation (per the tests,
    /// the f32 kernel currently reports "not yet implemented").
    pub async fn sum_f32(&self, data: &Float32Array) -> Result<f32> {
        kernels::sum_f32(&self.device, &self.queue, data).await
    }

    /// Counts elements of an arrow array on the GPU. Delegates to
    /// [`kernels::count`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the kernel implementation.
    pub async fn count(&self, data: &dyn Array) -> Result<usize> {
        kernels::count(&self.device, &self.queue, data).await
    }

    /// Minimum of an `Int32Array` on the GPU (per the tests, `i32::MAX` for
    /// an empty array). Delegates to [`kernels::min_i32`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the kernel implementation.
    pub async fn min_i32(&self, data: &Int32Array) -> Result<i32> {
        kernels::min_i32(&self.device, &self.queue, data).await
    }

    /// Maximum of an `Int32Array` on the GPU (per the tests, `i32::MIN` for
    /// an empty array). Delegates to [`kernels::max_i32`].
    ///
    /// # Errors
    ///
    /// Propagates any error from the kernel implementation.
    pub async fn max_i32(&self, data: &Int32Array) -> Result<i32> {
        kernels::max_i32(&self.device, &self.queue, data).await
    }

    /// Mean of a `Float32Array`, computed as GPU `sum / count`; returns
    /// `0.0` for an empty array instead of dividing by zero.
    ///
    /// # Errors
    ///
    /// Propagates errors from `sum_f32` or `count`.
    // Allowed: `count as f32` loses precision only for astronomically large
    // arrays; acceptable for an average.
    #[allow(clippy::cast_precision_loss)]
    pub async fn avg_f32(&self, data: &Float32Array) -> Result<f32> {
        let sum = self.sum_f32(data).await?;
        let count = self.count(data).await?;
        if count == 0 {
            Ok(0.0)
        } else {
            Ok(sum / count as f32)
        }
    }

    /// Computes `SUM(x) WHERE x <filter_op> filter_threshold` in a single
    /// fused GPU pass.
    ///
    /// The shader is JIT-specialized for the threshold and operator via
    /// [`jit::JitCompiler::compile_fused_filter_sum`]; no uniform buffer is
    /// bound, so the filter condition is baked into the compiled kernel.
    ///
    /// NOTE(review): `filter_op` is stringly-typed ("gt"/"lt"/"eq" per the
    /// tests) — an enum would catch typos at compile time; TODO consider.
    /// NOTE(review): `data.values()` reads the raw value buffer, so null
    /// slots (if any) would be included in the sum — confirm callers pass
    /// null-free arrays.
    ///
    /// # Errors
    ///
    /// Returns [`Error::Other`] if the result buffer cannot be mapped for
    /// readback.
    #[allow(clippy::too_many_lines)]
    // Allowed: `input_size as u32` — truncation only matters for > u32::MAX
    // elements, beyond practical dispatch sizes.
    #[allow(clippy::cast_possible_truncation)]
    pub async fn fused_filter_sum(
        &self,
        data: &Int32Array,
        filter_threshold: i32,
        filter_op: &str,
    ) -> Result<i32> {
        // Compile (or fetch) the shader specialized for this predicate.
        // Note this happens even for empty input; presumably the JIT caches
        // compiled modules — TODO confirm in `jit`.
        let shader_module =
            self.jit.compile_fused_filter_sum(&self.device, filter_threshold, filter_op);

        let input_data: Vec<i32> = data.values().to_vec();
        let input_size = input_data.len();

        // Empty input: the sum is trivially 0; skip all GPU work.
        if input_size == 0 {
            return Ok(0);
        }

        // Upload the input column as a read-only storage buffer.
        let input_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Fused Filter+Sum Input"),
            contents: bytemuck::cast_slice(&input_data),
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
        });

        // Single-i32 accumulator, zero-initialized; COPY_SRC so the result
        // can later be copied into the staging buffer for readback.
        let output_buffer = self.device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("Fused Filter+Sum Output"),
            contents: bytemuck::cast_slice(&[0i32]),
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
        });

        // Layout: binding 0 = read-only input, binding 1 = writable output.
        // Must match the bindings declared by the generated WGSL.
        let bind_group_layout =
            self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("Fused Filter+Sum Bind Group Layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Fused Filter+Sum Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
                wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
            ],
        });

        let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Fused Filter+Sum Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        let compute_pipeline =
            self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Fused Filter+Sum Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader_module,
                entry_point: "fused_filter_sum",
                compilation_options: wgpu::PipelineCompilationOptions::default(),
                cache: None,
            });

        let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("Fused Filter+Sum Encoder"),
        });

        // Scope the compute pass so it is ended (dropped) before the
        // buffer-to-buffer copy is recorded on the same encoder.
        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Fused Filter+Sum Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&compute_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // One workgroup per 256 elements — assumes the generated shader
            // declares @workgroup_size(256); confirm against `jit`.
            let workgroup_count = (input_size as u32).div_ceil(256);
            compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
        }

        // CPU-visible staging buffer for the 4-byte (one i32) result.
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Fused Filter+Sum Staging Buffer"),
            size: 4,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, 4);

        self.queue.submit(Some(encoder.finish()));

        // Bridge wgpu's callback-style map_async to async/await via a
        // oneshot channel; `Maintain::Wait` blocks until the submitted work
        // completes, which also drives the map callback.
        let buffer_slice = staging_buffer.slice(..);
        let (tx, rx) = futures_intrusive::channel::shared::oneshot_channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            tx.send(result).ok();
        });

        self.device.poll(wgpu::Maintain::Wait);

        // First `?`: channel closed without a message; second `?`: the map
        // operation itself failed.
        rx.receive()
            .await
            .ok_or_else(|| Error::Other("Failed to receive buffer map result".to_string()))?
            .map_err(|e| Error::Other(format!("Buffer mapping failed: {e}")))?;

        // Decode the single little-endian i32 result.
        let data_view = buffer_slice.get_mapped_range();
        let result = i32::from_le_bytes([data_view[0], data_view[1], data_view[2], data_view[3]]);

        // The mapped view must be dropped before unmap is legal.
        drop(data_view);
        staging_buffer.unmap();

        Ok(result)
    }
}
300
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    /// Try to bring up a `GpuEngine`; on failure (e.g. headless CI with no
    /// adapter) print the standard skip message and yield `None` so the
    /// caller can bail out early.
    async fn gpu_or_skip() -> Option<GpuEngine> {
        match GpuEngine::new().await {
            Ok(engine) => Some(engine),
            Err(_) => {
                eprintln!("Skipping GPU test (no GPU available)");
                None
            }
        }
    }

    /// Initialization itself: success and failure are both acceptable
    /// outcomes (failure is expected on CI), but a failure is logged.
    #[tokio::test]
    async fn test_gpu_init() {
        if let Err(e) = GpuEngine::new().await {
            eprintln!("GPU initialization failed (expected on CI): {e}");
        }
    }

    #[tokio::test]
    async fn test_gpu_sum_basic() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.sum_i32(&values).await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_gpu_sum_empty() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.sum_i32(&values).await.unwrap(), 0);
    }

    #[tokio::test]
    async fn test_gpu_min_i32() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![5, 2, 8, 1, 9]);
        assert_eq!(engine.min_i32(&values).await.unwrap(), 1);
    }

    /// MIN over an empty array is the identity element, `i32::MAX`.
    #[tokio::test]
    async fn test_gpu_min_empty() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.min_i32(&values).await.unwrap(), i32::MAX);
    }

    #[tokio::test]
    async fn test_gpu_max_i32() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![5, 2, 8, 1, 9]);
        assert_eq!(engine.max_i32(&values).await.unwrap(), 9);
    }

    /// MAX over an empty array is the identity element, `i32::MIN`.
    #[tokio::test]
    async fn test_gpu_max_empty() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.max_i32(&values).await.unwrap(), i32::MIN);
    }

    #[tokio::test]
    async fn test_gpu_count() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.count(&values).await.unwrap(), 5);
    }

    /// The f32 sum kernel is a stub; it must surface an error mentioning
    /// "not yet implemented".
    #[tokio::test]
    async fn test_gpu_sum_f32_not_implemented() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Float32Array::from(vec![1.0, 2.0, 3.0]);
        let err = engine.sum_f32(&values).await.unwrap_err();
        assert!(err.to_string().contains("not yet implemented"));
    }

    /// AVG depends on the f32 sum kernel, so it errors too.
    #[tokio::test]
    async fn test_gpu_avg_f32_not_implemented() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Float32Array::from(vec![2.0, 4.0, 6.0]);
        assert!(engine.avg_f32(&values).await.is_err());
    }

    /// SUM(x) WHERE x > 5 over 1..=10 → 6+7+8+9+10 = 40.
    #[tokio::test]
    async fn test_gpu_fused_filter_sum_gt() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        assert_eq!(engine.fused_filter_sum(&values, 5, "gt").await.unwrap(), 40);
    }

    /// SUM(x) WHERE x < 4 over 1..=5 → 1+2+3 = 6.
    #[tokio::test]
    async fn test_gpu_fused_filter_sum_lt() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.fused_filter_sum(&values, 4, "lt").await.unwrap(), 6);
    }

    /// SUM(x) WHERE x == 5 → three fives = 15.
    #[tokio::test]
    async fn test_gpu_fused_filter_sum_eq() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 5, 5, 3, 5]);
        assert_eq!(engine.fused_filter_sum(&values, 5, "eq").await.unwrap(), 15);
    }

    #[tokio::test]
    async fn test_gpu_fused_filter_sum_empty() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(Vec::<i32>::new());
        assert_eq!(engine.fused_filter_sum(&values, 5, "gt").await.unwrap(), 0);
    }

    /// A predicate nothing satisfies sums to 0.
    #[tokio::test]
    async fn test_gpu_fused_filter_sum_no_matches() {
        let Some(engine) = gpu_or_skip().await else { return };
        let values = Int32Array::from(vec![1, 2, 3, 4, 5]);
        assert_eq!(engine.fused_filter_sum(&values, 100, "gt").await.unwrap(), 0);
    }
}