use crate::common;
use crate::context::Context;
use crate::scan::Scanner;
use crate::sort::pipeline::SortPipeline;
use std::sync::Arc;

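/// GPU scratch space reused across sorts: the ping-pong key buffers, the
/// per-block digit histogram (raw and scanned), one uniform buffer per radix
/// pass, and the prebuilt (reduce, scatter) bind-group pairs.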
struct SortWorkspace {
    capacity_bytes: u64,
    buf_a: wgpu::Buffer,
    #[allow(dead_code)]
    buf_b: wgpu::Buffer,
    buf_hist: wgpu::Buffer,
    buf_scanned_hist: wgpu::Buffer,
    uniform_buffers: Vec<wgpu::Buffer>,
    bind_groups: Vec<(wgpu::BindGroup, wgpu::BindGroup)>,
}

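/// Sorts `u32` keys with a least-significant-digit GPU radix sort (16 passes of
/// 2 bits each), falling back to a CPU sort for inputs too small to be worth
/// the GPU round trip.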
pub struct Sorter {
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
    scanner: Scanner,
    pipeline: SortPipeline,
    workspace: Option<SortWorkspace>,
}

impl Sorter {
    pub fn new(ctx: &Context) -> Self {
        Self {
            device: Arc::new(ctx.device.clone()),
            queue: Arc::new(ctx.queue.clone()),
            scanner: Scanner::new(ctx),
            pipeline: SortPipeline::new(ctx),
            workspace: None,
        }
    }

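    /// Sorts `input` and returns the result. Inputs below `GPU_THRESHOLD`
    /// (1,000,000 elements) are sorted on the CPU with `sort_unstable`; larger
    /// inputs go through the GPU radix sort.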
    pub async fn sort(&mut self, input: &[u32]) -> Vec<u32> {
        const GPU_THRESHOLD: usize = 1_000_000;

        if input.len() < GPU_THRESHOLD {
            let mut data = input.to_vec();
            data.sort_unstable();
            data
        } else {
            self.sort_radix(input).await
        }
    }

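    /// Full GPU radix sort: uploads `input`, records 16 reduce/scan/scatter
    /// passes of 2 bits each, then downloads the sorted keys back to the host.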
    pub async fn sort_radix(&mut self, input: &[u32]) -> Vec<u32> {
        let n = input.len() as u64;
        let n_bytes = n * 4;

        let need_realloc = if let Some(ws) = &self.workspace {
            ws.capacity_bytes < n_bytes
        } else {
            true
        };

        if need_realloc {
            self.allocate_workspace(n_bytes);
        }

        let ws = self.workspace.as_mut().unwrap();

        self.queue
            .write_buffer(&ws.buf_a, 0, bytemuck::cast_slice(input));

        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let num_blocks = (n + items_per_block - 1) / items_per_block;

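        // One uniform per pass: [bit shift, element count, block count, padding].
        // Sixteen passes of 2 bits each cover all 32 bits of the keys.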
        for i in 0..16 {
            let bit = i * 2;
            let uniform_data = [bit as u32, n as u32, num_blocks as u32, 0];
            self.queue.write_buffer(
                &ws.uniform_buffers[i],
                0,
                bytemuck::cast_slice(&uniform_data),
            );
        }

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Fused Sort"),
            });

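        // Split the dispatch across x and y so no dimension exceeds the
        // 65_535 workgroups-per-dimension limit.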
        let max_dispatch = 65535;
        let x_groups = (num_blocks as u32).min(max_dispatch);
        let y_groups = (num_blocks as u32 + max_dispatch - 1) / max_dispatch;

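        // Each pass: per-block digit histograms (reduce), a prefix scan of the
        // histograms, then a scatter into the other ping-pong buffer using the
        // scanned offsets.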
        for i in 0..16 {
            let (reduce_bg, scatter_bg) = &ws.bind_groups[i];

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.reduce_pipeline);
                cpass.set_bind_group(0, reduce_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }

            self.scanner
                .record_scan(&mut encoder, &ws.buf_hist, &ws.buf_scanned_hist);

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.scatter_pipeline);
                cpass.set_bind_group(0, scatter_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }
        }

        self.queue.submit(Some(encoder.finish()));

        common::buffers::download_buffer(&self.device, &self.queue, &ws.buf_a, n_bytes).await
    }

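    /// Same GPU radix sort as `sort_radix`, but the result stays on the GPU:
    /// the sorted keys are left in the workspace's primary buffer and a
    /// reference to it is returned instead of reading the data back.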
    pub fn sort_resident(&mut self, input: &[u32]) -> &wgpu::Buffer {
        let n = input.len() as u64;
        let n_bytes = n * 4;

        let need_realloc = if let Some(ws) = &self.workspace {
            ws.capacity_bytes < n_bytes
        } else {
            true
        };

        if need_realloc {
            self.allocate_workspace(n_bytes);
        }

        let ws = self.workspace.as_mut().unwrap();

        self.queue
            .write_buffer(&ws.buf_a, 0, bytemuck::cast_slice(input));

        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let num_blocks = (n + items_per_block - 1) / items_per_block;

        for i in 0..16 {
            let bit = i * 2;
            let uniform_data = [bit as u32, n as u32, num_blocks as u32, 0];
            self.queue.write_buffer(
                &ws.uniform_buffers[i],
                0,
                bytemuck::cast_slice(&uniform_data),
            );
        }

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Fused Sort Resident"),
            });

        let max_dispatch = 65535;
        let x_groups = (num_blocks as u32).min(max_dispatch);
        let y_groups = (num_blocks as u32 + max_dispatch - 1) / max_dispatch;

        for i in 0..16 {
            let (reduce_bg, scatter_bg) = &ws.bind_groups[i];

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.reduce_pipeline);
                cpass.set_bind_group(0, reduce_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }

            self.scanner
                .record_scan(&mut encoder, &ws.buf_hist, &ws.buf_scanned_hist);

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor::default());
                cpass.set_pipeline(&self.pipeline.scatter_pipeline);
                cpass.set_bind_group(0, scatter_bg, &[]);
                cpass.dispatch_workgroups(x_groups, y_groups, 1);
            }
        }

        self.queue.submit(Some(encoder.finish()));

        &ws.buf_a
    }

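    /// (Re)allocates the reusable workspace: the two ping-pong key buffers, the
    /// per-block digit histogram and its scanned copy, one 16-byte uniform
    /// buffer per pass, and the 16 pairs of reduce/scatter bind groups.
    /// Capacity is rounded up (via `align_to`) to a 16 MiB multiple so small
    /// size increases do not force a reallocation every call.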
    fn allocate_workspace(&mut self, requested_size: u64) {
        let capacity = common::math::align_to(requested_size, 16 * 1024 * 1024);
        let items_per_block = (self.pipeline.vt * self.pipeline.block_size) as u64;
        let max_items = capacity / 4;
        let max_blocks = (max_items + items_per_block - 1) / items_per_block;

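        // 2-bit digits give 4 buckets per block, i.e. 4 * 4 = 16 histogram
        // bytes per block; the total is rounded up to a 256-byte multiple.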
        let hist_bytes = max_blocks * 16;
        let hist_bytes_aligned = common::math::align_to(hist_bytes, 256);

        let buf_a = common::buffers::create_empty_storage_buffer(&self.device, capacity);
        let buf_b = common::buffers::create_empty_storage_buffer(&self.device, capacity);
        let buf_hist =
            common::buffers::create_empty_storage_buffer(&self.device, hist_bytes_aligned);
        let buf_scanned_hist =
            common::buffers::create_empty_storage_buffer(&self.device, hist_bytes_aligned);

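        // One 16-byte uniform buffer per pass (bit shift, counts, padding), so
        // every pass can be encoded up front with its own constants.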
        let mut uniform_buffers = Vec::with_capacity(16);
        for _ in 0..16 {
            uniform_buffers.push(self.device.create_buffer(&wgpu::BufferDescriptor {
                label: Some("Sort Uniform"),
                size: 16,
                usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
                mapped_at_creation: false,
            }));
        }

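        // Bind groups alternate source and destination so the passes ping-pong
        // between buf_a and buf_b; with 16 passes (an even count) the sorted
        // keys end up back in buf_a.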
        let mut bind_groups = Vec::with_capacity(16);
        for i in 0..16 {
            let (source, dest) = if i % 2 == 0 {
                (&buf_a, &buf_b)
            } else {
                (&buf_b, &buf_a)
            };

            let reduce_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("Reduce BG"),
                layout: &self.pipeline.bind_group_layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: source.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: buf_hist.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: dest.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buffers[i].as_entire_binding(),
                    },
                ],
            });

            let scatter_bg = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
                label: Some("Scatter BG"),
                layout: &self.pipeline.bind_group_layout,
                entries: &[
                    wgpu::BindGroupEntry {
                        binding: 0,
                        resource: source.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 1,
                        resource: buf_scanned_hist.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 2,
                        resource: dest.as_entire_binding(),
                    },
                    wgpu::BindGroupEntry {
                        binding: 3,
                        resource: uniform_buffers[i].as_entire_binding(),
                    },
                ],
            });

            bind_groups.push((reduce_bg, scatter_bg));
        }

        self.workspace = Some(SortWorkspace {
            capacity_bytes: capacity,
            buf_a,
            buf_b,
            buf_hist,
            buf_scanned_hist,
            uniform_buffers,
            bind_groups,
        });
    }
}