// wgpu_algorithms/scan/scanner.rs

use super::pipeline::ScanPipeline;
use crate::{common, context::Context};
use std::sync::Arc;

/// Runs a multi-level block scan over `u32` data on the GPU.
///
/// Owns the scan/add compute pipelines, shared handles to the device and queue, and a
/// scratch buffer for the per-block intermediate values. The scratch buffer is kept
/// between calls and only reallocated when a larger one is needed.
pub struct Scanner {
    /// Scan and add pipelines together with their block configuration (`vt`, `block_size`).
    pipeline: ScanPipeline,
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    /// Scratch storage holding the block values of every reduction level; `None` until the
    /// first scan allocates it.
    pub scratch_buffer: Option<wgpu::Buffer>,
    /// Size in bytes of the scratch buffer as last allocated by this scanner.
    scratch_size_bytes: u64,
}

13impl Scanner {
14    pub fn new(ctx: &Context) -> Self {
15        Self {
16            pipeline: ScanPipeline::new(ctx),
17            device: Arc::new(ctx.device.clone()),
18            queue: Arc::new(ctx.queue.clone()),
19            scratch_buffer: None,
20            scratch_size_bytes: 0,
21        }
22    }
23
24    pub async fn scan(&mut self, input: &[u32]) -> Vec<u32> {
25        let data_buffer = common::buffers::create_storage_buffer(&self.device, input);
26        let dst_buffer =
27            common::buffers::create_empty_storage_buffer(&self.device, data_buffer.size());
28
29        self.scan_gpu_to_gpu(&data_buffer, &dst_buffer).await;
30
31        let size_bytes = (input.len() * 4) as u64;
32        common::buffers::download_buffer(&self.device, &self.queue, &dst_buffer, size_bytes).await
33    }
34
35    pub async fn scan_gpu_to_gpu(&mut self, input_buf: &wgpu::Buffer, output_buf: &wgpu::Buffer) {
36        let mut encoder = self
37            .device
38            .create_command_encoder(&wgpu::CommandEncoderDescriptor::default());
39        self.record_scan(&mut encoder, input_buf, output_buf);
40        self.queue.submit(Some(encoder.finish()));
41    }
42
43    pub fn record_scan(
44        &mut self,
45        encoder: &mut wgpu::CommandEncoder,
46        input_buf: &wgpu::Buffer,
47        output_buf: &wgpu::Buffer,
48    ) {
49        let size_bytes = input_buf.size();
50        let num_items = (size_bytes / 4) as u32;
51
52        self.prepare_scratch(num_items);
53
54        encoder.copy_buffer_to_buffer(input_buf, 0, output_buf, 0, size_bytes);
55
56        let scratch = self.scratch_buffer.as_ref().unwrap();
57
58        struct Level<'a> {
59            buf: &'a wgpu::Buffer,
60            offset: u64,
61            count: u32,
62        }
63
64        let mut levels = Vec::new();
65        levels.push(Level {
66            buf: output_buf,
67            offset: 0,
68            count: num_items,
69        });
70
71        let mut current_scratch_offset = 0u64;
72
73        loop {
74            let current = levels.last().unwrap();
75            if current.count <= 1 {
76                break;
77            }
78
79            let items_per_block = self.pipeline.vt * self.pipeline.block_size;
80
81            let aux_count = (current.count + items_per_block - 1) / items_per_block;
82            let aux_size = (aux_count * 4) as u64;
83            let aux_offset = crate::common::math::align_to(current_scratch_offset, 256);
84
85            self.pipeline.dispatch(
86                &self.device,
87                encoder,
88                &self.pipeline.scan_pipeline,
89                current.buf,
90                current.offset,
91                scratch,
92                aux_offset,
93                current.count,
94            );
95
96            levels.push(Level {
97                buf: scratch,
98                offset: aux_offset,
99                count: aux_count,
100            });
101            current_scratch_offset = aux_offset + aux_size;
102        }
103
104        for i in (0..levels.len() - 1).rev() {
105            let data_level = &levels[i];
106            let aux_level = &levels[i + 1];
107
108            self.pipeline.dispatch(
109                &self.device,
110                encoder,
111                &self.pipeline.add_pipeline,
112                data_level.buf,
113                data_level.offset,
114                aux_level.buf,
115                aux_level.offset,
116                data_level.count,
117            );
118        }
119    }
120
121    fn prepare_scratch(&mut self, num_items: u32) {
122        let needed_bytes = self.pipeline.get_scratch_size(num_items);
123        if self.scratch_buffer.is_none() || needed_bytes > self.scratch_size_bytes {
124            self.scratch_buffer = Some(self.device.create_buffer(&wgpu::BufferDescriptor {
125                label: Some("Scanner Scratch"),
126                size: needed_bytes,
127                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
128                mapped_at_creation: false,
129            }));
130            self.scratch_size_bytes = needed_bytes;
131        }
132    }
133}