wgpu_algorithms/sort/
sorter.rs

use crate::common;
use crate::context::Context;
use crate::scan::Scanner;
use crate::sort::pipeline::SortPipeline;
use std::sync::Arc;

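/// Reusable GPU workspace, sized for the largest input seen so far.
/// `buf_a` and `buf_b` are the ping-pong key buffers; `buf_hist` holds the
/// per-block digit histograms and `buf_scanned_hist` their exclusive scan.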
struct SortWorkspace {
    capacity_bytes: u64,
    buf_a: wgpu::Buffer,
    // Only reachable through the bind groups after creation, hence the
    // dead-code allowance; the field keeps the ping-pong pair explicit.
    #[allow(dead_code)]
    buf_b: wgpu::Buffer,
    buf_hist: wgpu::Buffer,
    buf_scanned_hist: wgpu::Buffer,
    uniform_buffers: Vec<wgpu::Buffer>,
    bind_groups: Vec<(wgpu::BindGroup, wgpu::BindGroup)>,
}

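/// Hybrid sorter for `u32` keys: CPU `sort_unstable` for small inputs, a
/// 16-pass least-significant-digit GPU radix sort (2 bits per pass) for
/// large ones.
///
/// A minimal usage sketch, assuming `ctx` is this crate's [`Context`]
/// wrapper around a wgpu device/queue (construction not shown here):
///
/// ```ignore
/// let mut sorter = Sorter::new(&ctx);
/// let sorted = sorter.sort(&data).await;
/// assert!(sorted.windows(2).all(|w| w[0] <= w[1]));
/// ```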
pub struct Sorter {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    scanner: Scanner,
    pipeline: SortPipeline,
    workspace: Option<SortWorkspace>,
}

impl Sorter {
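    /// Builds a sorter on the context's device and queue. The scan and
    /// radix-sort pipelines are created up front so later calls only
    /// encode work.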
    pub fn new(ctx: &Context) -> Self {
        Self {
            device: Arc::new(ctx.device.clone()),
            queue: Arc::new(ctx.queue.clone()),
            scanner: Scanner::new(ctx),
            pipeline: SortPipeline::new(ctx),
            workspace: None,
        }
    }

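    /// Sorts `input` ascending, dispatching by size: below `GPU_THRESHOLD`
    /// the upload/download overhead outweighs the GPU's throughput, so a
    /// plain CPU sort wins.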
    pub async fn sort(&mut self, input: &[u32]) -> Vec<u32> {
        const GPU_THRESHOLD: usize = 1_000_000;

        if input.len() < GPU_THRESHOLD {
            let mut data = input.to_vec();
            data.sort_unstable();
            data
        } else {
            self.sort_radix(input).await
        }
    }

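    /// Full GPU radix sort: uploads `input`, encodes 16 reduce/scan/scatter
    /// passes of 2 bits each (covering all 32 bits of the keys), submits,
    /// and downloads the result. After an even number of ping-pong passes
    /// the sorted keys are back in `buf_a`.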
    pub async fn sort_radix(&mut self, input: &[u32]) -> Vec<u32> {
        let n = input.len() as u64;
        let n_bytes = n * 4;

        let need_realloc = if let Some(ws) = &self.workspace {
            ws.capacity_bytes < n_bytes
        } else {
            true
        };

        if need_realloc {
            self.allocate_workspace(n_bytes);
        }

        let ws = self.workspace.as_mut().unwrap();

        self.queue
            .write_buffer(&ws.buf_a, 0, bytemuck::cast_slice(input));

        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let num_blocks = (n + items_per_block - 1) / items_per_block;

        // One uniform buffer per pass: (bit offset, element count, block count, pad).
        for i in 0..16 {
            let bit = i * 2;
            let uniform_data = [bit as u32, n as u32, num_blocks as u32, 0];
            self.queue.write_buffer(
                &ws.uniform_buffers[i],
                0,
                bytemuck::cast_slice(&uniform_data),
            );
        }

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Fused Sort"),
            });

        // A single dispatch dimension is capped (65 535 by default), so fold
        // the block count into a 2D grid; the shaders are expected to rebuild
        // the linear block index and bounds-check against `num_blocks`.
        let max_dispatch: u32 = 65_535;
        let x_groups = (num_blocks as u32).min(max_dispatch);
        let y_groups = (num_blocks as u32 + max_dispatch - 1) / max_dispatch;

        // Each pass: per-block digit histograms (reduce), exclusive scan of
        // the histograms, then scatter into the destination buffer. The bind
        // groups alternate buf_a/buf_b, so pass 15 writes back into buf_a.
        for i in 0..16 {
            let (reduce_bg, scatter_bg) = &ws.bind_groups[i];

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.reduce_pipeline);
                cpass.set_bind_group(0, reduce_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }

            self.scanner
                .record_scan(&mut encoder, &ws.buf_hist, &ws.buf_scanned_hist);

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.scatter_pipeline);
                cpass.set_bind_group(0, scatter_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }
        }

        self.queue.submit(Some(encoder.finish()));

        common::buffers::download_buffer(&self.device, &self.queue, &ws.buf_a, n_bytes).await
    }

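    /// Identical to [`Sorter::sort_radix`] except that the sorted keys stay
    /// on the GPU: the result buffer is returned instead of downloaded, so
    /// the output can feed further GPU passes. Work submitted later on the
    /// same queue observes the sorted contents.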
    pub fn sort_resident(&mut self, input: &[u32]) -> &wgpu::Buffer {
        let n = input.len() as u64;
        let n_bytes = n * 4;

        let need_realloc = if let Some(ws) = &self.workspace {
            ws.capacity_bytes < n_bytes
        } else {
            true
        };

        if need_realloc {
            self.allocate_workspace(n_bytes);
        }

        let ws = self.workspace.as_mut().unwrap();

        self.queue
            .write_buffer(&ws.buf_a, 0, bytemuck::cast_slice(input));

        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let num_blocks = (n + items_per_block - 1) / items_per_block;

        for i in 0..16 {
            let bit = i * 2;
            let uniform_data = [bit as u32, n as u32, num_blocks as u32, 0];
            self.queue.write_buffer(
                &ws.uniform_buffers[i],
                0,
                bytemuck::cast_slice(&uniform_data),
            );
        }

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Fused Sort Resident"),
            });

        let max_dispatch: u32 = 65_535;
        let x_groups = (num_blocks as u32).min(max_dispatch);
        let y_groups = (num_blocks as u32 + max_dispatch - 1) / max_dispatch;

        for i in 0..16 {
            let (reduce_bg, scatter_bg) = &ws.bind_groups[i];

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.reduce_pipeline);
                cpass.set_bind_group(0, reduce_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }

            self.scanner
                .record_scan(&mut encoder, &ws.buf_hist, &ws.buf_scanned_hist);

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.scatter_pipeline);
                cpass.set_bind_group(0, scatter_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }
        }

        self.queue.submit(Some(encoder.finish()));

        // 16 ping-pong passes end with the sorted keys in buf_a.
        &ws.buf_a
    }

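    /// (Re)creates the workspace for at least `requested_size` bytes of keys.
    /// Capacity is rounded up to a 16 MiB multiple so sorts of similar sizes
    /// reuse the allocation, and all 16 per-pass uniform buffers and bind
    /// groups are built once here rather than on every sort.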
    fn allocate_workspace(&mut self, requested_size: u64) {
        let capacity = common::math::align_to(requested_size, 16 * 1024 * 1024);
        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let max_items = capacity / 4;
        let max_blocks = (max_items + items_per_block - 1) / items_per_block;

        // 2-bit digits mean 4 histogram buckets of u32 per block (16 bytes),
        // padded to the 256-byte alignment used for storage bindings.
        let hist_bytes = max_blocks * 16;
        let hist_bytes_aligned = common::math::align_to(hist_bytes, 256);

        let buf_a = common::buffers::create_empty_storage_buffer(&self.device, capacity);
        let buf_b = common::buffers::create_empty_storage_buffer(&self.device, capacity);
        let buf_hist =
            common::buffers::create_empty_storage_buffer(&self.device, hist_bytes_aligned);
        let buf_scanned_hist =
            common::buffers::create_empty_storage_buffer(&self.device, hist_bytes_aligned);

        let mut uniform_buffers = Vec::with_capacity(16);
        for _ in 0..16 {
            uniform_buffers.push(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Sort Uniform"),
                size: 16,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
        }

        let mut bind_groups = Vec::with_capacity(16);
        for i in 0..16 {
            // Even passes read buf_a and write buf_b; odd passes swap, so the
            // final (16th) pass deposits the sorted keys back in buf_a.
            let (source, dest) = if i % 2 == 0 {
                (&buf_a, &buf_b)
            } else {
                (&buf_b, &buf_a)
            };

            // Reduce reads the source keys and writes raw histograms.
            let reduce_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("Reduce BG"),
                layout: &self.pipeline.bind_group_layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: source.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: buf_hist.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: dest.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buffers[i].as_entire_binding(),
                    },
                ],
            });

            // Scatter reads the same source but consumes the scanned histograms.
            let scatter_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("Scatter BG"),
                layout: &self.pipeline.bind_group_layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: source.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: buf_scanned_hist.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: dest.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buffers[i].as_entire_binding(),
                    },
                ],
            });

            bind_groups.push((reduce_bg, scatter_bg));
        }

        self.workspace = Some(SortWorkspace {
            capacity_bytes: capacity,
            buf_a,
            buf_b,
            buf_hist,
            buf_scanned_hist,
            uniform_buffers,
            bind_groups,
        });
    }
}