webgpu_groth16/bucket/mod.rs
use crate::gpu::curve::{G1MsmDecomposition, GpuCurve};
15
/// Flattened bucket decomposition of an MSM instance, produced by the
/// `compute_*bucket*` functions below and consumed by the GPU dispatch code.
///
/// Two parallel layouts are stored: the "orig" arrays describe one entry per
/// non-empty bucket, while the primary arrays describe one entry per
/// *dispatched* unit — a bucket, or one chunk of a bucket that was split
/// because it exceeded `G::MSM_MAX_CHUNK_SIZE`.
pub struct BucketData {
    // Point indices, grouped contiguously by bucket. An entry may carry
    // `G::MSM_INDEX_SIGN_BIT` to mark a negated point.
    pub base_indices: Vec<u32>,
    // Per dispatched bucket/chunk: start offset into `base_indices`.
    pub bucket_pointers: Vec<u32>,
    // Per dispatched bucket/chunk: number of point indices.
    pub bucket_sizes: Vec<u32>,
    // Per dispatched bucket/chunk: the bucket's (unsigned) digit value.
    pub bucket_values: Vec<u32>,
    // Per window: first index into the dispatched-bucket arrays.
    pub window_starts: Vec<u32>,
    // Per window: number of dispatched buckets/chunks in that window.
    pub window_counts: Vec<u32>,
    // Number of scalar windows (max over all decomposed scalars).
    pub num_windows: u32,
    // Number of non-empty buckets before chunk splitting.
    pub num_active_buckets: u32,
    // Total dispatched buckets/chunks (length of the primary arrays).
    pub num_dispatched: u32,
    // Bucket values before chunk splitting, one per active bucket.
    pub orig_bucket_values: Vec<u32>,
    // Per window: first index into the pre-split ("orig") arrays.
    pub orig_window_starts: Vec<u32>,
    // Per window: number of active buckets before splitting.
    pub orig_window_counts: Vec<u32>,
    // Per active bucket: index of its first chunk in the dispatched arrays.
    pub reduce_starts: Vec<u32>,
    // Per active bucket: number of chunks it was split into (1 if unsplit).
    pub reduce_counts: Vec<u32>,
    // True iff at least one bucket exceeded `G::MSM_MAX_CHUNK_SIZE`.
    pub has_chunks: bool,
    // Window width `c` (bits) used for the decomposition.
    pub bucket_width: usize,
}
71
72impl BucketData {
73 #[cfg(feature = "timing")]
76 pub fn print_distribution_stats(&self, label: &str) {
77 if self.num_active_buckets == 0 {
78 eprintln!("[bucket-diag] {label}: 0 active buckets");
79 return;
80 }
81 let mut sizes: Vec<u32> = self.bucket_sizes.clone();
82 sizes.sort();
83 let n = sizes.len();
84 let total: u32 = sizes.iter().sum();
85 let max = *sizes.last().unwrap();
86 let min = *sizes.first().unwrap();
87 let mean = total as f64 / n as f64;
88 let median = sizes[n / 2];
89 let p90 = sizes[(n * 90) / 100];
90 let p95 = sizes[(n * 95) / 100];
91 let p99 = sizes[n.saturating_sub(1).min((n * 99) / 100)];
92
93 let over_64 = sizes.iter().filter(|&&s| s > 64).count();
94 let over_256 = sizes.iter().filter(|&&s| s > 256).count();
95 let over_1024 = sizes.iter().filter(|&&s| s > 1024).count();
96
97 eprintln!(
98 "[bucket-diag] {label}: {n} active buckets, {total} total points, \
99 c={}",
100 self.bucket_width
101 );
102 eprintln!(
103 "[bucket-diag] min={min} max={max} mean={mean:.1} \
104 median={median}"
105 );
106 eprintln!("[bucket-diag] p90={p90} p95={p95} p99={p99}");
107 eprintln!(
108 "[bucket-diag] >64: {over_64} >256: {over_256} >1024: \
109 {over_1024}"
110 );
111
112 for w in 0..self.num_windows as usize {
114 let start = self.window_starts[w] as usize;
115 let count = self.window_counts[w] as usize;
116 if count == 0 {
117 continue;
118 }
119 let w_sizes: Vec<u32> = (start..start + count)
120 .map(|i| self.bucket_sizes[i])
121 .collect();
122 let w_max = *w_sizes.iter().max().unwrap();
123 let w_total: u32 = w_sizes.iter().sum();
124 let max_idx = w_sizes.iter().position(|&s| s == w_max).unwrap();
126 let max_val = self.bucket_values[start + max_idx];
127 if w_max > 32 {
128 eprintln!(
129 "[bucket-diag] window {w}: {count} buckets, \
130 max_size={w_max} (val={max_val}), total={w_total}"
131 );
132 }
133 }
134 }
135}
136
/// Distribute signed window digits into buckets and flatten them into the
/// GPU-friendly `BucketData` layout.
///
/// `all_windows[i][w]` is the `(magnitude, is_negative)` digit of point `i`
/// in window `w`; `c` is the window width in bits. Two passes:
/// 1. group point indices by (window, digit magnitude), recording the
///    pre-split ("orig") layout;
/// 2. split any bucket larger than `G::MSM_MAX_CHUNK_SIZE` into chunks,
///    recording via `reduce_starts`/`reduce_counts` how to recombine the
///    chunk partial sums.
fn build_bucket_data<G: GpuCurve>(
    all_windows: &[Vec<(u32, bool)>],
    c: usize,
) -> BucketData {
    // Scalars may decompose into differing window counts; take the max.
    let num_windows = all_windows.iter().map(|w| w.len()).max().unwrap_or(0);
    // Signed digits have magnitude 0..=2^(c-1); slot 0 exists but stays
    // empty because zero digits are skipped below.
    let num_buckets = (1usize << (c - 1)) + 1;

    let mut base_indices = Vec::new();
    let mut orig_pointers = Vec::new();
    let mut orig_sizes = Vec::new();
    let mut orig_values = Vec::new();
    let mut orig_window_starts = Vec::new();
    let mut orig_window_counts = Vec::new();

    // Pass 1: bucket the point indices, window by window.
    for w in 0..num_windows {
        let mut buckets: Vec<Vec<u32>> = vec![Vec::new(); num_buckets];

        for (i, windows) in all_windows.iter().enumerate() {
            if w < windows.len() {
                let (abs, neg) = windows[w];
                if abs != 0 {
                    // Negative digits are encoded by setting the sign bit
                    // on the point index; the GPU negates the base.
                    let entry = if neg {
                        i as u32 | G::MSM_INDEX_SIGN_BIT
                    } else {
                        i as u32
                    };
                    buckets[abs as usize].push(entry);
                }
            }
        }

        orig_window_starts.push(orig_values.len() as u32);
        let mut count = 0u32;

        // Emit only non-empty buckets, in ascending digit-value order.
        for (val, indices) in buckets.into_iter().enumerate() {
            if !indices.is_empty() {
                orig_pointers.push(base_indices.len() as u32);
                orig_sizes.push(indices.len() as u32);
                orig_values.push(val as u32);
                base_indices.extend(indices);
                count += 1;
            }
        }
        orig_window_counts.push(count);
    }

    let num_active_buckets = orig_sizes.len() as u32;

    let mut bucket_pointers = Vec::new();
    let mut bucket_sizes = Vec::new();
    let mut bucket_values = Vec::new();
    let mut window_starts = Vec::new();
    let mut window_counts = Vec::new();
    let mut reduce_starts = Vec::new();
    let mut reduce_counts = Vec::new();
    let mut has_chunks = false;

    // Pass 2: split oversized buckets into <= MSM_MAX_CHUNK_SIZE chunks so
    // per-dispatch work stays bounded.
    for w in 0..num_windows {
        let w_start = orig_window_starts[w] as usize;
        let w_count = orig_window_counts[w] as usize;
        window_starts.push(bucket_pointers.len() as u32);
        let mut dispatched_in_window = 0u32;

        for b in 0..w_count {
            let orig_idx = w_start + b;
            let ptr = orig_pointers[orig_idx];
            let size = orig_sizes[orig_idx];
            let val = orig_values[orig_idx];

            // Where this bucket's chunks begin in the dispatched arrays;
            // recorded so the reduce step can sum them back together.
            let sub_start = bucket_pointers.len() as u32;

            if size <= G::MSM_MAX_CHUNK_SIZE {
                // Fits in one dispatch: a single "chunk" of count 1.
                bucket_pointers.push(ptr);
                bucket_sizes.push(size);
                bucket_values.push(val);
                reduce_starts.push(sub_start);
                reduce_counts.push(1);
                dispatched_in_window += 1;
            } else {
                has_chunks = true;
                let num_chunks = size.div_ceil(G::MSM_MAX_CHUNK_SIZE);
                for chunk in 0..num_chunks {
                    let chunk_start = ptr + chunk * G::MSM_MAX_CHUNK_SIZE;
                    // Last chunk may be short.
                    let chunk_size = (size - chunk * G::MSM_MAX_CHUNK_SIZE)
                        .min(G::MSM_MAX_CHUNK_SIZE);
                    bucket_pointers.push(chunk_start);
                    bucket_sizes.push(chunk_size);
                    bucket_values.push(val);
                    dispatched_in_window += 1;
                }
                reduce_starts.push(sub_start);
                reduce_counts.push(num_chunks);
            }
        }
        window_counts.push(dispatched_in_window);
    }

    let num_dispatched = bucket_pointers.len() as u32;

    BucketData {
        base_indices,
        bucket_pointers,
        bucket_sizes,
        bucket_values,
        window_starts,
        window_counts,
        num_windows: num_windows as u32,
        num_active_buckets,
        num_dispatched,
        orig_bucket_values: orig_values,
        orig_window_starts,
        orig_window_counts,
        reduce_starts,
        reduce_counts,
        has_chunks,
        bucket_width: c,
    }
}
276
/// Pick the bucket width `c` the curve considers optimal for a GLV MSM
/// over `n` points (delegates to the curve implementation).
pub fn optimal_glv_c<G: GpuCurve>(n: usize) -> usize {
    G::g1_msm_bucket_width(n)
}
280
/// Bucket-sort `scalars` using the curve's default bucket width.
///
/// Convenience wrapper around [`compute_bucket_sorting_with_width`].
pub fn compute_bucket_sorting<G: GpuCurve>(
    scalars: &[G::Scalar],
) -> BucketData {
    compute_bucket_sorting_with_width::<G>(scalars, G::bucket_width())
}
286
287pub fn compute_bucket_sorting_with_width<G: GpuCurve>(
288 scalars: &[G::Scalar],
289 c: usize,
290) -> BucketData {
291 let all_windows: Vec<Vec<(u32, bool)>> = scalars
292 .iter()
293 .map(|s| G::scalar_to_signed_windows(s, c))
294 .collect();
295 build_bucket_data::<G>(&all_windows, c)
296}
297
/// GLV bucket sorting that folds the decomposition signs into the bases.
///
/// For each scalar, splits it into two half-length sub-scalars (k1, k2)
/// via the curve's GLV decomposition, and pairs them with the base point
/// and its endomorphism image phi(P) respectively. A negative sub-scalar
/// is handled by negating the corresponding base's GPU byte encoding, so
/// the window digits themselves stay non-negated.
///
/// Returns the interleaved base bytes (P then phi(P) per GLV scalar) and
/// the bucket data built from the combined window lists. When the curve
/// has no GLV support, falls back to plain bucket sorting over the
/// original bases.
pub fn compute_glv_bucket_sorting<G: GpuCurve>(
    scalars: &[G::Scalar],
    bases_bytes: &[u8],
    phi_bases_bytes: &[u8],
    c: usize,
) -> (Vec<u8>, BucketData) {
    if !G::HAS_G1_GLV {
        let bd = compute_bucket_sorting_with_width::<G>(scalars, c);
        return (bases_bytes.to_vec(), bd);
    }

    let n = scalars.len();
    // Both base buffers must hold exactly one GPU-encoded point per scalar.
    debug_assert_eq!(bases_bytes.len(), n * G::G1_GPU_BYTES);
    debug_assert_eq!(phi_bases_bytes.len(), n * G::G1_GPU_BYTES);

    // Each GLV scalar contributes two (base, windows) pairs.
    let mut combined_bases = Vec::with_capacity(n * 2 * G::G1_GPU_BYTES);
    let mut all_windows: Vec<Vec<(u32, bool)>> = Vec::with_capacity(n * 2);

    for (i, scalar) in scalars.iter().enumerate() {
        if let Some((k1_windows, k1_neg, k2_windows, k2_neg)) =
            G::decompose_g1_msm_scalar_glv_windows(scalar, c)
        {
            let src_start = i * G::G1_GPU_BYTES;
            // k1 pairs with P; a negative k1 is absorbed by negating P.
            let mut p_bytes =
                bases_bytes[src_start..src_start + G::G1_GPU_BYTES].to_vec();
            if k1_neg {
                G::negate_g1_base_bytes(&mut p_bytes);
            }
            combined_bases.extend_from_slice(&p_bytes);

            // k2 pairs with phi(P); same sign-absorption trick.
            let mut phi_bytes = phi_bases_bytes
                [src_start..src_start + G::G1_GPU_BYTES]
                .to_vec();
            if k2_neg {
                G::negate_g1_base_bytes(&mut phi_bytes);
            }
            combined_bases.extend_from_slice(&phi_bytes);

            all_windows.push(k1_windows);
            all_windows.push(k2_windows);
        } else if let G1MsmDecomposition::Standard { windows } =
            G::decompose_g1_msm_scalar(scalar, c)
        {
            // GLV decomposition unavailable for this scalar: keep the
            // original base with a standard window decomposition.
            let src_start = i * G::G1_GPU_BYTES;
            combined_bases.extend_from_slice(
                &bases_bytes[src_start..src_start + G::G1_GPU_BYTES],
            );
            all_windows.push(windows);
        }
        // NOTE(review): any other G1MsmDecomposition variant is silently
        // skipped here (base contributes nothing) — presumably a
        // zero-scalar case; confirm against the GpuCurve implementation.
    }

    (combined_bases, build_bucket_data::<G>(&all_windows, c))
}
362
/// GLV bucket sorting that folds the decomposition signs into the window
/// digits instead of the bases.
///
/// Unlike [`compute_glv_bucket_sorting`], no base bytes are touched: when
/// a sub-scalar (k1 or k2) is negative, the sign flag of every non-zero
/// window digit is flipped, which is equivalent to negating the whole
/// sub-scalar. Falls back to plain bucket sorting when the curve has no
/// GLV support.
pub fn compute_glv_bucket_data<G: GpuCurve>(
    scalars: &[G::Scalar],
    c: usize,
) -> BucketData {
    if !G::HAS_G1_GLV {
        return compute_bucket_sorting_with_width::<G>(scalars, c);
    }

    let n = scalars.len();
    // Each GLV scalar contributes two window lists (k1 and k2).
    let mut all_windows: Vec<Vec<(u32, bool)>> = Vec::with_capacity(n * 2);

    for scalar in scalars.iter() {
        if let Some((mut k1_windows, k1_neg, mut k2_windows, k2_neg)) =
            G::decompose_g1_msm_scalar_glv_windows(scalar, c)
        {
            // Negative k1: flip the sign of every non-zero digit (zero
            // digits carry no sign and are left untouched).
            if k1_neg {
                for w in &mut k1_windows {
                    if w.0 != 0 {
                        w.1 = !w.1;
                    }
                }
            }
            all_windows.push(k1_windows);

            // Same treatment for a negative k2.
            if k2_neg {
                for w in &mut k2_windows {
                    if w.0 != 0 {
                        w.1 = !w.1;
                    }
                }
            }
            all_windows.push(k2_windows);
        } else if let G1MsmDecomposition::Standard { windows } =
            G::decompose_g1_msm_scalar(scalar, c)
        {
            // GLV unavailable for this scalar: standard decomposition.
            all_windows.push(windows);
        }
        // NOTE(review): other G1MsmDecomposition variants are skipped —
        // presumably the zero-scalar case; confirm with the trait impl.
    }

    build_bucket_data::<G>(&all_windows, c)
}
411
412#[cfg(test)]
413mod tests;