// File: wgpu_algorithms/scan/pipeline.rs

1use crate::common;
2use crate::context::Context;
3
4pub struct ScanPipeline {
5    pub bind_group_layout: wgpu::BindGroupLayout,
6    pub scan_pipeline: wgpu::ComputePipeline,
7    pub add_pipeline: wgpu::ComputePipeline,
8    pub vt: u32,
9    pub block_size: u32,
10}
11
12impl ScanPipeline {
13    pub fn new(ctx: &Context) -> Self {
14        let bind_group_layout =
15            ctx.device
16                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
17                    label: Some("Scan Layout"),
18                    entries: &[
19                        common::buffers::bind_entry(0, false, false),
20                        common::buffers::bind_entry(1, false, false),
21                    ],
22                });
23
24        let limits = ctx.device.limits();
25        let max_shared_mem = limits.max_compute_workgroup_storage_size;
26
27        // High End (M3/Desktop): 32KB+ shared mem -> Use VT=8, Block=256
28        // Low End (Mobile): <32KB shared mem -> Use VT=4, Block=128 (Lower register pressure)
29        let (vt, block_size) = if max_shared_mem >= 32768 {
30            (8, 256)
31        } else {
32            log::warn!("Low-end GPU detected. Downgrading to VT=4.");
33            (4, 128)
34        };
35
36        let config = common::shader::ShaderConfig { vt, block_size };
37
38        let scan_pipeline = common::shader::create_compute_pipeline(
39            &ctx.device,
40            &bind_group_layout,
41            include_str!("scan.wgsl"),
42            &format!("Scan VT{} Pipeline", vt),
43            "main",
44            Some(&config),
45        );
46
47        let add_pipeline = common::shader::create_compute_pipeline(
48            &ctx.device,
49            &bind_group_layout,
50            include_str!("add.wgsl"),
51            &format!("Add VT{} Pipeline", vt),
52            "main",
53            Some(&config),
54        );
55
56        Self {
57            bind_group_layout,
58            scan_pipeline,
59            add_pipeline,
60            vt,
61            block_size,
62        }
63    }
64
65    pub fn get_scratch_size(&self, num_items: u32) -> u64 {
66        let mut size = 0;
67        let mut current_items = num_items;
68
69        let items_per_block = self.vt * self.block_size;
70
71        while current_items > 1 {
72            let aux_count = (current_items + items_per_block - 1) / items_per_block;
73            let raw_size = (aux_count * 4) as u64;
74            let aligned_size = common::math::align_to(raw_size, 256);
75            size += aligned_size;
76            current_items = aux_count;
77        }
78        size
79    }
80
81    pub fn dispatch(
82        &self,
83        device: &wgpu::Device,
84        encoder: &mut wgpu::CommandEncoder,
85        pipeline: &wgpu::ComputePipeline,
86        data_buf: &wgpu::Buffer,
87        data_off: u64,
88        aux_buf: &wgpu::Buffer,
89        aux_off: u64,
90        num_items: u32,
91    ) {
92        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
93        cpass.set_pipeline(pipeline);
94
95        let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
96            label: Some("Scan Dispatch BG"),
97            layout: &self.bind_group_layout,
98            entries: &[
99                wgpu::BindGroupEntry {
100                    binding: 0,
101                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
102                        buffer: data_buf,
103                        offset: data_off,
104                        size: None,
105                    }),
106                },
107                wgpu::BindGroupEntry {
108                    binding: 1,
109                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
110                        buffer: aux_buf,
111                        offset: aux_off,
112                        size: None,
113                    }),
114                },
115            ],
116        });
117        cpass.set_bind_group(0, &bg, &[]);
118
119        let items_per_block = self.vt * self.block_size;
120        let workgroups = common::math::calc_groups(num_items, items_per_block);
121
122        let max_dispatch = 65535;
123        let x = if workgroups > max_dispatch {
124            max_dispatch
125        } else {
126            workgroups
127        };
128        let y = (workgroups + max_dispatch - 1) / max_dispatch;
129
130        cpass.dispatch_workgroups(x, y, 1);
131    }
132}