1use crate::{Error, Result};
11use arrow::array::{Array, Float32Array, Int32Array};
12use wgpu;
13use wgpu::util::DeviceExt;
14
/// Invocations per compute workgroup. Must stay in sync with the
/// hard-coded `@workgroup_size(256)` and `array<_, 256>` declarations in
/// every WGSL shader below, and with their initial stride of `128u`.
const WORKGROUP_SIZE: u32 = 256;
17
/// WGSL parallel-reduction SUM over `array<i32>`.
///
/// Each 256-thread workgroup loads its slice into shared memory (lanes past
/// the end load `0`, the additive identity), halves the stride each pass,
/// and thread 0 folds the workgroup's partial sum into `output[0]` with
/// `atomicAdd` — so one dispatch handles any input length. The host must
/// pre-seed `output[0]` with `0`.
///
/// NOTE(review): the `gid + stride < input_size` guard looks redundant —
/// out-of-range lanes already hold the identity — and `workgroup_id` is
/// unused; both appear harmless but worth confirming/cleaning up.
const SUM_I32_SHADER: &str = r"
@group(0) @binding(0) var<storage, read> input: array<i32>;
@group(0) @binding(1) var<storage, read_write> output: array<atomic<i32>>;

var<workgroup> shared_data: array<i32, 256>;

@compute @workgroup_size(256)
fn sum_reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>,
              @builtin(workgroup_id) workgroup_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let input_size = arrayLength(&input);

    // Load data into shared memory
    if (gid < input_size) {
        shared_data[tid] = input[gid];
    } else {
        shared_data[tid] = 0;
    }
    workgroupBarrier();

    // Parallel reduction in shared memory
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride && gid + stride < input_size) {
            shared_data[tid] += shared_data[tid + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // First thread writes workgroup result
    if (tid == 0u) {
        atomicAdd(&output[0], shared_data[0]);
    }
}
";
57
/// WGSL parallel-reduction SUM over `array<f32>` (currently unused —
/// `sum_f32` is a stub).
///
/// NOTE(review): the final `output[0] += shared_data[0]` is a plain
/// read-modify-write, so when more than one workgroup is dispatched they
/// race on `output[0]` and partial sums can be lost (WGSL has no
/// `atomic<f32>`). Before wiring this into `sum_f32`, switch to writing
/// each workgroup's partial into `output[workgroup_id.x]` and reduce in a
/// second pass (or on the host).
#[allow(dead_code)]
const SUM_F32_SHADER: &str = r"
@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<storage, read_write> output: array<f32>;

var<workgroup> shared_data: array<f32, 256>;

@compute @workgroup_size(256)
fn sum_reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let input_size = arrayLength(&input);

    // Load data into shared memory
    if (gid < input_size) {
        shared_data[tid] = input[gid];
    } else {
        shared_data[tid] = 0.0;
    }
    workgroupBarrier();

    // Parallel reduction in shared memory
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride && gid + stride < input_size) {
            shared_data[tid] += shared_data[tid + stride];
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // First thread writes workgroup result
    if (tid == 0u) {
        output[0] += shared_data[0];
    }
}
";
97
/// WGSL kernel that counts in-bounds invocations via `atomicAdd`.
///
/// `@ARRAY_SIZE@` is a textual placeholder: it must be substituted with the
/// element count before `create_shader_module`, as the string does not
/// compile as-is. Currently unused — `count()` answers from `Array::len()`
/// on the host without dispatching.
#[allow(dead_code)]
const COUNT_SHADER: &str = r"
@group(0) @binding(0) var<storage, read_write> output: array<atomic<u32>>;

@compute @workgroup_size(256)
fn count_kernel(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let array_size: u32 = @ARRAY_SIZE@;

    if (global_id.x < array_size) {
        atomicAdd(&output[0], 1u);
    }
}
";
112
/// WGSL parallel-reduction MIN over `array<i32>`.
///
/// Out-of-range lanes load `i32::MAX` (the MIN identity) so partial
/// workgroups reduce correctly; thread 0 folds each workgroup's minimum
/// into `output[0]` with `atomicMin`. The host must pre-seed `output[0]`
/// with `i32::MAX` (see `min_i32`, which does exactly that).
///
/// NOTE(review): this constant carries `#[allow(dead_code)]` but is
/// referenced by `min_i32` — the attribute looks stale.
#[allow(dead_code)]
const MIN_I32_SHADER: &str = r"
@group(0) @binding(0) var<storage, read> input: array<i32>;
@group(0) @binding(1) var<storage, read_write> output: array<atomic<i32>>;

var<workgroup> shared_data: array<i32, 256>;

@compute @workgroup_size(256)
fn min_reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let input_size = arrayLength(&input);

    // Load data into shared memory
    if (gid < input_size) {
        shared_data[tid] = input[gid];
    } else {
        shared_data[tid] = 2147483647; // i32::MAX
    }
    workgroupBarrier();

    // Parallel reduction in shared memory
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride && gid + stride < input_size) {
            shared_data[tid] = min(shared_data[tid], shared_data[tid + stride]);
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // First thread writes workgroup result
    if (tid == 0u) {
        atomicMin(&output[0], shared_data[0]);
    }
}
";
152
/// WGSL parallel-reduction MAX over `array<i32>`.
///
/// Out-of-range lanes load `i32::MIN` (the MAX identity) so partial
/// workgroups reduce correctly; thread 0 folds each workgroup's maximum
/// into `output[0]` with `atomicMax`. The host must pre-seed `output[0]`
/// with `i32::MIN` (see `max_i32`, which does exactly that).
///
/// NOTE(review): this constant carries `#[allow(dead_code)]` but is
/// referenced by `max_i32` — the attribute looks stale.
#[allow(dead_code)]
const MAX_I32_SHADER: &str = r"
@group(0) @binding(0) var<storage, read> input: array<i32>;
@group(0) @binding(1) var<storage, read_write> output: array<atomic<i32>>;

var<workgroup> shared_data: array<i32, 256>;

@compute @workgroup_size(256)
fn max_reduce(@builtin(global_invocation_id) global_id: vec3<u32>,
              @builtin(local_invocation_id) local_id: vec3<u32>) {
    let tid = local_id.x;
    let gid = global_id.x;
    let input_size = arrayLength(&input);

    // Load data into shared memory
    if (gid < input_size) {
        shared_data[tid] = input[gid];
    } else {
        shared_data[tid] = -2147483648; // i32::MIN
    }
    workgroupBarrier();

    // Parallel reduction in shared memory
    var stride = 128u;
    while (stride > 0u) {
        if (tid < stride && gid + stride < input_size) {
            shared_data[tid] = max(shared_data[tid], shared_data[tid + stride]);
        }
        workgroupBarrier();
        stride = stride / 2u;
    }

    // First thread writes workgroup result
    if (tid == 0u) {
        atomicMax(&output[0], shared_data[0]);
    }
}
";
192
193#[allow(clippy::too_many_lines)]
201#[allow(clippy::cast_possible_truncation)]
202pub async fn sum_i32(device: &wgpu::Device, queue: &wgpu::Queue, data: &Int32Array) -> Result<i32> {
203 let input_data: Vec<i32> = data.values().to_vec();
204 let input_size = input_data.len();
205
206 if input_size == 0 {
207 return Ok(0);
208 }
209
210 let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
212 label: Some("Input Buffer"),
213 contents: bytemuck::cast_slice(&input_data),
214 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
215 });
216
217 let output_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
219 label: Some("Output Buffer"),
220 contents: bytemuck::cast_slice(&[0i32]),
221 usage: wgpu::BufferUsages::STORAGE
222 | wgpu::BufferUsages::COPY_SRC
223 | wgpu::BufferUsages::COPY_DST,
224 });
225
226 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
228 label: Some("SUM i32 Shader"),
229 source: wgpu::ShaderSource::Wgsl(SUM_I32_SHADER.into()),
230 });
231
232 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
233 label: Some("Bind Group Layout"),
234 entries: &[
235 wgpu::BindGroupLayoutEntry {
236 binding: 0,
237 visibility: wgpu::ShaderStages::COMPUTE,
238 ty: wgpu::BindingType::Buffer {
239 ty: wgpu::BufferBindingType::Storage { read_only: true },
240 has_dynamic_offset: false,
241 min_binding_size: None,
242 },
243 count: None,
244 },
245 wgpu::BindGroupLayoutEntry {
246 binding: 1,
247 visibility: wgpu::ShaderStages::COMPUTE,
248 ty: wgpu::BindingType::Buffer {
249 ty: wgpu::BufferBindingType::Storage { read_only: false },
250 has_dynamic_offset: false,
251 min_binding_size: None,
252 },
253 count: None,
254 },
255 ],
256 });
257
258 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
259 label: Some("Pipeline Layout"),
260 bind_group_layouts: &[&bind_group_layout],
261 push_constant_ranges: &[],
262 });
263
264 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
265 label: Some("SUM i32 Pipeline"),
266 layout: Some(&pipeline_layout),
267 module: &shader,
268 entry_point: "sum_reduce",
269 compilation_options: wgpu::PipelineCompilationOptions::default(),
270 cache: None,
271 });
272
273 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
275 label: Some("Bind Group"),
276 layout: &bind_group_layout,
277 entries: &[
278 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
279 wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
280 ],
281 });
282
283 let mut encoder = device
285 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Compute Encoder") });
286
287 {
288 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
289 label: Some("Compute Pass"),
290 timestamp_writes: None,
291 });
292 compute_pass.set_pipeline(&compute_pipeline);
293 compute_pass.set_bind_group(0, &bind_group, &[]);
294
295 let workgroup_count = (input_size as u32).div_ceil(WORKGROUP_SIZE);
296 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
297 }
298
299 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
301 label: Some("Staging Buffer"),
302 size: 4, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
304 mapped_at_creation: false,
305 });
306
307 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, 4);
308 queue.submit(Some(encoder.finish()));
309
310 let buffer_slice = staging_buffer.slice(..);
312 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
313 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
314 sender.send(result).expect("Failed to send buffer mapping result through channel");
315 });
316 device.poll(wgpu::Maintain::Wait);
317
318 receiver
319 .receive()
320 .await
321 .ok_or_else(|| Error::Other("Failed to receive mapping result".to_string()))?
322 .map_err(|e| Error::Other(format!("Buffer mapping failed: {e:?}")))?;
323
324 let data = buffer_slice.get_mapped_range();
325 let result = i32::from_le_bytes(
326 data[0..4].try_into().expect("Buffer must contain at least 4 bytes for i32 result"),
327 );
328 drop(data);
329 staging_buffer.unmap();
330
331 Ok(result)
332}
333
334#[allow(clippy::unused_async)]
340pub async fn sum_f32(
341 _device: &wgpu::Device,
342 _queue: &wgpu::Queue,
343 _data: &Float32Array,
344) -> Result<f32> {
345 Err(Error::Other("f32 SUM not yet implemented".to_string()))
347}
348
349#[allow(clippy::unused_async)]
355pub async fn count(
356 _device: &wgpu::Device,
357 _queue: &wgpu::Queue,
358 data: &dyn Array,
359) -> Result<usize> {
360 Ok(data.len())
362}
363
364#[allow(clippy::too_many_lines)]
372#[allow(clippy::cast_possible_truncation)]
373pub async fn min_i32(device: &wgpu::Device, queue: &wgpu::Queue, data: &Int32Array) -> Result<i32> {
374 let input_data: Vec<i32> = data.values().to_vec();
375 let input_size = input_data.len();
376
377 if input_size == 0 {
378 return Ok(i32::MAX); }
380
381 let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
383 label: Some("MIN Input Buffer"),
384 contents: bytemuck::cast_slice(&input_data),
385 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
386 });
387
388 let output_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
390 label: Some("MIN Output Buffer"),
391 contents: bytemuck::cast_slice(&[i32::MAX]),
392 usage: wgpu::BufferUsages::STORAGE
393 | wgpu::BufferUsages::COPY_SRC
394 | wgpu::BufferUsages::COPY_DST,
395 });
396
397 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
399 label: Some("MIN i32 Shader"),
400 source: wgpu::ShaderSource::Wgsl(MIN_I32_SHADER.into()),
401 });
402
403 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
404 label: Some("MIN Bind Group Layout"),
405 entries: &[
406 wgpu::BindGroupLayoutEntry {
407 binding: 0,
408 visibility: wgpu::ShaderStages::COMPUTE,
409 ty: wgpu::BindingType::Buffer {
410 ty: wgpu::BufferBindingType::Storage { read_only: true },
411 has_dynamic_offset: false,
412 min_binding_size: None,
413 },
414 count: None,
415 },
416 wgpu::BindGroupLayoutEntry {
417 binding: 1,
418 visibility: wgpu::ShaderStages::COMPUTE,
419 ty: wgpu::BindingType::Buffer {
420 ty: wgpu::BufferBindingType::Storage { read_only: false },
421 has_dynamic_offset: false,
422 min_binding_size: None,
423 },
424 count: None,
425 },
426 ],
427 });
428
429 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
430 label: Some("MIN Pipeline Layout"),
431 bind_group_layouts: &[&bind_group_layout],
432 push_constant_ranges: &[],
433 });
434
435 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
436 label: Some("MIN i32 Pipeline"),
437 layout: Some(&pipeline_layout),
438 module: &shader,
439 entry_point: "min_reduce",
440 compilation_options: wgpu::PipelineCompilationOptions::default(),
441 cache: None,
442 });
443
444 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
446 label: Some("MIN Bind Group"),
447 layout: &bind_group_layout,
448 entries: &[
449 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
450 wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
451 ],
452 });
453
454 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
456 label: Some("MIN Compute Encoder"),
457 });
458
459 {
460 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
461 label: Some("MIN Compute Pass"),
462 timestamp_writes: None,
463 });
464 compute_pass.set_pipeline(&compute_pipeline);
465 compute_pass.set_bind_group(0, &bind_group, &[]);
466
467 let workgroup_count = (input_size as u32).div_ceil(WORKGROUP_SIZE);
468 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
469 }
470
471 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
473 label: Some("MIN Staging Buffer"),
474 size: 4, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
476 mapped_at_creation: false,
477 });
478
479 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, 4);
480 queue.submit(Some(encoder.finish()));
481
482 let buffer_slice = staging_buffer.slice(..);
484 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
485 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
486 sender.send(result).expect("Failed to send buffer mapping result through channel");
487 });
488 device.poll(wgpu::Maintain::Wait);
489
490 receiver
491 .receive()
492 .await
493 .ok_or_else(|| Error::Other("Failed to receive mapping result".to_string()))?
494 .map_err(|e| Error::Other(format!("Buffer mapping failed: {e:?}")))?;
495
496 let data = buffer_slice.get_mapped_range();
497 let result = i32::from_le_bytes(
498 data[0..4].try_into().expect("Buffer must contain at least 4 bytes for i32 result"),
499 );
500 drop(data);
501 staging_buffer.unmap();
502
503 Ok(result)
504}
505
506#[allow(clippy::too_many_lines)]
514#[allow(clippy::cast_possible_truncation)]
515pub async fn max_i32(device: &wgpu::Device, queue: &wgpu::Queue, data: &Int32Array) -> Result<i32> {
516 let input_data: Vec<i32> = data.values().to_vec();
517 let input_size = input_data.len();
518
519 if input_size == 0 {
520 return Ok(i32::MIN); }
522
523 let input_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
525 label: Some("MAX Input Buffer"),
526 contents: bytemuck::cast_slice(&input_data),
527 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
528 });
529
530 let output_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
532 label: Some("MAX Output Buffer"),
533 contents: bytemuck::cast_slice(&[i32::MIN]),
534 usage: wgpu::BufferUsages::STORAGE
535 | wgpu::BufferUsages::COPY_SRC
536 | wgpu::BufferUsages::COPY_DST,
537 });
538
539 let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
541 label: Some("MAX i32 Shader"),
542 source: wgpu::ShaderSource::Wgsl(MAX_I32_SHADER.into()),
543 });
544
545 let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
546 label: Some("MAX Bind Group Layout"),
547 entries: &[
548 wgpu::BindGroupLayoutEntry {
549 binding: 0,
550 visibility: wgpu::ShaderStages::COMPUTE,
551 ty: wgpu::BindingType::Buffer {
552 ty: wgpu::BufferBindingType::Storage { read_only: true },
553 has_dynamic_offset: false,
554 min_binding_size: None,
555 },
556 count: None,
557 },
558 wgpu::BindGroupLayoutEntry {
559 binding: 1,
560 visibility: wgpu::ShaderStages::COMPUTE,
561 ty: wgpu::BindingType::Buffer {
562 ty: wgpu::BufferBindingType::Storage { read_only: false },
563 has_dynamic_offset: false,
564 min_binding_size: None,
565 },
566 count: None,
567 },
568 ],
569 });
570
571 let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
572 label: Some("MAX Pipeline Layout"),
573 bind_group_layouts: &[&bind_group_layout],
574 push_constant_ranges: &[],
575 });
576
577 let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
578 label: Some("MAX i32 Pipeline"),
579 layout: Some(&pipeline_layout),
580 module: &shader,
581 entry_point: "max_reduce",
582 compilation_options: wgpu::PipelineCompilationOptions::default(),
583 cache: None,
584 });
585
586 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
588 label: Some("MAX Bind Group"),
589 layout: &bind_group_layout,
590 entries: &[
591 wgpu::BindGroupEntry { binding: 0, resource: input_buffer.as_entire_binding() },
592 wgpu::BindGroupEntry { binding: 1, resource: output_buffer.as_entire_binding() },
593 ],
594 });
595
596 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
598 label: Some("MAX Compute Encoder"),
599 });
600
601 {
602 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
603 label: Some("MAX Compute Pass"),
604 timestamp_writes: None,
605 });
606 compute_pass.set_pipeline(&compute_pipeline);
607 compute_pass.set_bind_group(0, &bind_group, &[]);
608
609 let workgroup_count = (input_size as u32).div_ceil(WORKGROUP_SIZE);
610 compute_pass.dispatch_workgroups(workgroup_count, 1, 1);
611 }
612
613 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
615 label: Some("MAX Staging Buffer"),
616 size: 4, usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
618 mapped_at_creation: false,
619 });
620
621 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, 4);
622 queue.submit(Some(encoder.finish()));
623
624 let buffer_slice = staging_buffer.slice(..);
626 let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
627 buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
628 sender.send(result).expect("Failed to send buffer mapping result through channel");
629 });
630 device.poll(wgpu::Maintain::Wait);
631
632 receiver
633 .receive()
634 .await
635 .ok_or_else(|| Error::Other("Failed to receive mapping result".to_string()))?
636 .map_err(|e| Error::Other(format!("Buffer mapping failed: {e:?}")))?;
637
638 let data = buffer_slice.get_mapped_range();
639 let result = i32::from_le_bytes(
640 data[0..4].try_into().expect("Buffer must contain at least 4 bytes for i32 result"),
641 );
642 drop(data);
643 staging_buffer.unmap();
644
645 Ok(result)
646}
647
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::Int32Array;

    /// Tries to acquire a GPU device/queue pair for a test.
    ///
    /// Returns `None` when no adapter is available or device creation
    /// fails, letting tests skip gracefully on GPU-less machines instead
    /// of duplicating the acquisition boilerplate in every test.
    async fn try_gpu() -> Option<(wgpu::Device, wgpu::Queue)> {
        let instance = wgpu::Instance::default();
        let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions::default()).await?;
        adapter.request_device(&wgpu::DeviceDescriptor::default(), None).await.ok()
    }

    #[tokio::test]
    async fn test_count_returns_array_length() {
        let data = Int32Array::from(vec![1, 2, 3, 4, 5]);
        let Some((device, queue)) = try_gpu().await else {
            eprintln!("Skipping GPU test (no GPU or device available)");
            return;
        };

        let result = count(&device, &queue, &data).await.unwrap();
        assert_eq!(result, 5);
    }

    #[tokio::test]
    async fn test_count_empty_array() {
        let data = Int32Array::from(vec![] as Vec<i32>);
        let Some((device, queue)) = try_gpu().await else {
            eprintln!("Skipping GPU test (no GPU or device available)");
            return;
        };

        let result = count(&device, &queue, &data).await.unwrap();
        assert_eq!(result, 0);
    }

    #[tokio::test]
    async fn test_sum_f32_not_implemented() {
        let data = Float32Array::from(vec![1.0, 2.0, 3.0]);
        let Some((device, queue)) = try_gpu().await else {
            eprintln!("Skipping GPU test (no GPU or device available)");
            return;
        };

        let result = sum_f32(&device, &queue, &data).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not yet implemented"));
    }
}