wgpu_algorithms/scan/
scanner.rs1use super::pipeline::ScanPipeline;
2use crate::{common, context::Context};
3use std::sync::Arc;
4
5pub struct Scanner {
6 pipeline: ScanPipeline,
7 device: Arc<wgpu::Device>,
8 queue: Arc<wgpu::Queue>,
9 pub scratch_buffer: Option<wgpu::Buffer>,
10 scratch_size_bytes: u64,
11}
12
13impl Scanner {
14 pub fn new(ctx: &Context) -> Self {
15 Self {
16 pipeline: ScanPipeline::new(ctx),
17 device: Arc::new(ctx.device.clone()),
18 queue: Arc::new(ctx.queue.clone()),
19 scratch_buffer: None,
20 scratch_size_bytes: 0,
21 }
22 }
23
24 pub async fn scan(&mut self, input: &[u32]) -> Vec<u32> {
25 let data_buffer = common::buffers::create_storage_buffer(&self.device, input);
26 let dst_buffer =
27 common::buffers::create_empty_storage_buffer(&self.device, data_buffer.size());
28
29 self.scan_gpu_to_gpu(&data_buffer, &dst_buffer).await;
30
31 let size_bytes = (input.len() * 4) as u64;
32 common::buffers::download_buffer(&self.device, &self.queue, &dst_buffer, size_bytes).await
33 }
34
35 pub async fn scan_gpu_to_gpu(&mut self, input_buf: &wgpu::Buffer, output_buf: &wgpu::Buffer) {
36 let mut encoder = self
37 .device
38 .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
39 self.record_scan(&mut encoder, input_buf, output_buf);
40 self.queue.submit(Some(encoder.finish()));
41 }
42
43 pub fn record_scan(
44 &mut self,
45 encoder: &mut wgpu::CommandEncoder,
46 input_buf: &wgpu::Buffer,
47 output_buf: &wgpu::Buffer,
48 ) {
49 let size_bytes = input_buf.size();
50 let num_items = (size_bytes / 4) as u32;
51
52 self.prepare_scratch(num_items);
53
54 encoder.copy_buffer_to_buffer(input_buf, 0, output_buf, 0, size_bytes);
55
56 let scratch = self.scratch_buffer.as_ref().unwrap();
57
58 struct Level<'a> {
59 buf: &'a wgpu::Buffer,
60 offset: u64,
61 count: u32,
62 }
63
64 let mut levels = Vec::new();
65 levels.push(Level {
66 buf: output_buf,
67 offset: 0,
68 count: num_items,
69 });
70
71 let mut current_scratch_offset = 0u64;
72
73 loop {
74 let current = levels.last().unwrap();
75 if current.count <= 1 {
76 break;
77 }
78
79 let items_per_block = self.pipeline.vt * self.pipeline.block_size;
80
81 let aux_count = (current.count + items_per_block - 1) / items_per_block;
82 let aux_size = (aux_count * 4) as u64;
83 let aux_offset = crate::common::math::align_to(current_scratch_offset, 256);
84
85 self.pipeline.dispatch(
86 &self.device,
87 encoder,
88 &self.pipeline.scan_pipeline,
89 current.buf,
90 current.offset,
91 scratch,
92 aux_offset,
93 current.count,
94 );
95
96 levels.push(Level {
97 buf: scratch,
98 offset: aux_offset,
99 count: aux_count,
100 });
101 current_scratch_offset = aux_offset + aux_size;
102 }
103
104 for i in (0..levels.len() - 1).rev() {
105 let data_level = &levels[i];
106 let aux_level = &levels[i + 1];
107
108 self.pipeline.dispatch(
109 &self.device,
110 encoder,
111 &self.pipeline.add_pipeline,
112 data_level.buf,
113 data_level.offset,
114 aux_level.buf,
115 aux_level.offset,
116 data_level.count,
117 );
118 }
119 }
120
121 fn prepare_scratch(&mut self, num_items: u32) {
122 let needed_bytes = self.pipeline.get_scratch_size(num_items);
123 if self.scratch_buffer.is_none() || needed_bytes > self.scratch_size_bytes {
124 self.scratch_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
125 label: Some("Scanner Scratch"),
126 size: needed_bytes,
127 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
128 mapped_at_creation: false,
129 }));
130 self.scratch_size_bytes = needed_bytes;
131 }
132 }
133}