wgpu_algorithms/scan/pipeline.rs

use crate::common;
use crate::context::Context;

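/// Compute pipelines for a GPU prefix scan: a block-level scan pass plus an
/// "add" pass that folds per-block partial sums back into the data. `vt`
/// (values per thread) times `block_size` is the number of items each
/// workgroup covers.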
pub struct ScanPipeline {
    pub bind_group_layout: wgpu::BindGroupLayout,
    pub scan_pipeline: wgpu::ComputePipeline,
    pub add_pipeline: wgpu::ComputePipeline,
    pub vt: u32,
    pub block_size: u32,
}

impl ScanPipeline {
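    /// Builds the shared bind group layout and both compute pipelines,
    /// choosing the VT / block-size configuration from the device's
    /// workgroup storage limit.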
    pub fn new(ctx: &Context) -> Self {
        let bind_group_layout =
            ctx.device
                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                    label: Some("Scan Layout"),
                    entries: &[
                        common::buffers::bind_entry(0, false, false),
                        common::buffers::bind_entry(1, false, false),
                    ],
                });

        let limits = ctx.device.limits();
        let max_shared_mem = limits.max_compute_workgroup_storage_size;

        // Pick the tile configuration from the available workgroup (shared)
        // memory; fall back to a smaller tile on low-end devices.
        let (vt, block_size) = if max_shared_mem >= 32768 {
            (8, 256)
        } else {
            log::warn!("Low-end GPU detected. Downgrading to VT=4.");
            (4, 128)
        };

        let config = common::shader::ShaderConfig { vt, block_size };

        let scan_pipeline = common::shader::create_compute_pipeline(
            &ctx.device,
            &bind_group_layout,
            include_str!("scan.wgsl"),
            &format!("Scan VT{} Pipeline", vt),
            "main",
            Some(&config),
        );

        let add_pipeline = common::shader::create_compute_pipeline(
            &ctx.device,
            &bind_group_layout,
            include_str!("add.wgsl"),
            &format!("Add VT{} Pipeline", vt),
            "main",
            Some(&config),
        );

        Self {
            bind_group_layout,
            scan_pipeline,
            add_pipeline,
            vt,
            block_size,
        }
    }

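    /// Scratch bytes needed to scan `num_items`: one 4-byte partial sum per
    /// block at every reduction level, each level's slice rounded up to 256
    /// bytes to match the default storage-buffer offset alignment. For
    /// example, with vt = 8 and block_size = 256 (2048 items per block),
    /// 1_000_000 items need 489 block sums (1956 B, aligned to 2048 B) plus
    /// one sum at the next level (4 B, aligned to 256 B), i.e. 2304 B total.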
    pub fn get_scratch_size(&self, num_items: u32) -> u64 {
        let mut size = 0;
        let mut current_items = num_items;

        let items_per_block = self.vt * self.block_size;

        while current_items > 1 {
            let aux_count = (current_items + items_per_block - 1) / items_per_block;
            let raw_size = (aux_count * 4) as u64;
            let aligned_size = common::math::align_to(raw_size, 256);
            size += aligned_size;
            current_items = aux_count;
        }
        size
    }

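    /// Records a single compute pass of `pipeline` over `num_items` elements,
    /// binding `data_buf` at binding 0 and `aux_buf` at binding 1. When the
    /// workgroup count exceeds the 65_535 per-dimension limit, the dispatch is
    /// spread over x and y (possibly overshooting), so the shader presumably
    /// flattens the 2D group id and bounds-checks it against the item count.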
    pub fn dispatch(
        &self,
        device: &wgpu::Device,
        encoder: &mut wgpu::CommandEncoder,
        pipeline: &wgpu::ComputePipeline,
        data_buf: &wgpu::Buffer,
        data_off: u64,
        aux_buf: &wgpu::Buffer,
        aux_off: u64,
        num_items: u32,
    ) {
        let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
        cpass.set_pipeline(pipeline);

        let bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Scan Dispatch BG"),
            layout: &self.bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                        buffer: data_buf,
                        offset: data_off,
                        size: None,
                    }),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: wgpu::BindingResource::Buffer(wgpu::BufferBinding {
                        buffer: aux_buf,
                        offset: aux_off,
                        size: None,
                    }),
                },
            ],
        });
        cpass.set_bind_group(0, &bg, &[]);

        let items_per_block = self.vt * self.block_size;
        let workgroups = common::math::calc_groups(num_items, items_per_block);

        // Split the dispatch across x/y when the group count exceeds the
        // 65_535 per-dimension limit.
        let max_dispatch = 65535;
        let x = workgroups.min(max_dispatch);
        let y = (workgroups + max_dispatch - 1) / max_dispatch;

        cpass.dispatch_workgroups(x, y, 1);
    }
}
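
// A minimal usage sketch, assuming `Context` also exposes a `queue` handle
// (only `ctx.device` appears above): it sizes the scratch buffer with
// `get_scratch_size` and records the block-level scan pass. Chaining the
// recursive scans of the per-block sums and the `add_pipeline` pass is left
// to the caller's driver code.
#[allow(dead_code)]
fn example_scan(ctx: &Context, scan: &ScanPipeline, data: &wgpu::Buffer, num_items: u32) {
    // Scratch buffer sized for every reduction level of this input.
    let scratch = ctx.device.create_buffer(&wgpu::BufferDescriptor {
        label: Some("Scan Scratch"),
        size: scan.get_scratch_size(num_items),
        usage: wgpu::BufferUsages::STORAGE,
        mapped_at_creation: false,
    });

    let mut encoder = ctx
        .device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("Scan") });

    // Block-level scan: scans each tile of `vt * block_size` items of `data`
    // and writes one partial sum per block into `scratch` at offset 0.
    scan.dispatch(
        &ctx.device,
        &mut encoder,
        &scan.scan_pipeline,
        data,
        0,
        &scratch,
        0,
        num_items,
    );

    ctx.queue.submit(Some(encoder.finish()));
}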