Skip to main content

webgpu_groth16/gpu/
h_poly.rs

1//! H polynomial GPU pipeline dispatcher.
2//!
3//! Computes the quotient polynomial H(x) = (A(x)·B(x) − C(x)) / Z(x) entirely
4//! on the GPU using a single command encoder with the following steps:
5//!
6//! 1. **To Montgomery**: Convert A, B, C, twiddle factors, shift arrays, and
7//!    Z⁻¹ into Montgomery domain for efficient modular arithmetic
//! 2. **Fused iNTT + Coset shift(A, B, C)**: Inverse NTT with the shift factors
//!    multiplied in during write-back, avoiding a separate coset_shift dispatch
//!    (sizes above the NTT tile size still run coset_shift as its own pass)
10//! 3. **NTT(A, B, C)**: Forward NTT to get evaluation representations on the
11//!    coset
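//! 4. **Pointwise H = (A·B − C) · Z⁻¹**: Element-wise computation in the
//!    evaluation domain; this is not a separate dispatch but is fused into the
//!    first pass of step 5
//! 5. **Fused iNTT + Inverse coset shift(H)**: iNTT of H with the inverse shift
//!    fused in (again a separate coset_shift pass above the NTT tile size)
//! 6. **From Montgomery(H)**: Convert H out of Montgomery domain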
16
17use wgpu::util::DeviceExt;
18
19use super::curve::GpuCurve;
20use super::{GpuContext, HPolyBuffers, compute_pass};
21
22impl<C: GpuCurve> GpuContext<C> {
23    pub fn execute_h_pipeline(&self, bufs: &HPolyBuffers<'_>, n: u32) {
24        let mut encoder = self.device.create_command_encoder(
25            &wgpu::CommandEncoderDescriptor {
26                label: Some("H Pipeline Encoder"),
27            },
28        );
29
30        let mont_bg = |buf: &wgpu::Buffer| {
31            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
32                label: Some("Montgomery BG"),
33                layout: &self.montgomery_bind_group_layout,
34                entries: &[wgpu::BindGroupEntry {
35                    binding: 0,
36                    resource: buf.as_entire_binding(),
37                }],
38            })
39        };
40
41        let ntt_bg = |data: &wgpu::Buffer, tw: &wgpu::Buffer| {
42            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
43                label: Some("NTT BG"),
44                layout: &self.ntt_bind_group_layout,
45                entries: &[
46                    wgpu::BindGroupEntry {
47                        binding: 0,
48                        resource: data.as_entire_binding(),
49                    },
50                    wgpu::BindGroupEntry {
51                        binding: 1,
52                        resource: tw.as_entire_binding(),
53                    },
54                ],
55            })
56        };
57
58        let fused_shift_bg = |shifts: &wgpu::Buffer| {
59            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
60                label: Some("NTT Fused Shift BG"),
61                layout: &self.ntt_fused_shift_bgl,
62                entries: &[wgpu::BindGroupEntry {
63                    binding: 0,
64                    resource: shifts.as_entire_binding(),
65                }],
66            })
67        };
68
69        let pointwise_fused_bg =
70            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
71                label: Some("Pointwise Fused BG"),
72                layout: &self.pointwise_fused_bind_group_layout,
73                entries: &[
74                    wgpu::BindGroupEntry {
75                        binding: 0,
76                        resource: bufs.a.as_entire_binding(),
77                    },
78                    wgpu::BindGroupEntry {
79                        binding: 1,
80                        resource: bufs.b.as_entire_binding(),
81                    },
82                    wgpu::BindGroupEntry {
83                        binding: 2,
84                        resource: bufs.c.as_entire_binding(),
85                    },
86                    wgpu::BindGroupEntry {
87                        binding: 3,
88                        resource: bufs.z_invs.as_entire_binding(),
89                    },
90                ],
91            });
92
93        #[cfg(feature = "profiling")]
94        let mut profiler_guard = self.profiler.lock().unwrap();
95        #[cfg(feature = "profiling")]
96        let mut scope = profiler_guard.scope("h_pipeline", &mut encoder);
97
98        // Keep uniform buffers alive until the queue is submitted!
99        let mut param_updates = Vec::new();
100
101        // 1. To Montgomery
102        for bg in [
103            mont_bg(bufs.a),
104            mont_bg(bufs.b),
105            mont_bg(bufs.c),
106            mont_bg(bufs.twiddles_inv),
107            mont_bg(bufs.twiddles_fwd),
108            mont_bg(bufs.shifts),
109            mont_bg(bufs.inv_shifts),
110            mont_bg(bufs.z_invs),
111        ] {
112            let mut pass = compute_pass!(scope, encoder, "to_montgomery");
113            pass.set_pipeline(&self.to_montgomery_pipeline);
114            pass.set_bind_group(0, &bg, &[]);
115            pass.dispatch_workgroups(
116                n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
117                1,
118                1,
119            );
120        }
121
122        macro_rules! encode_ntt {
123            (
124                $label:expr,
125                $data_buf:expr,
126                $tw_buf:expr,
127                $is_fused_shift:expr,
128                $shifts_buf:expr,
129                $is_h_fused_pointwise:expr
130            ) => {
131                if n <= C::NTT_TILE_SIZE {
132                    let bg = ntt_bg($data_buf, $tw_buf);
133                    if $is_h_fused_pointwise {
134                        let shifts_group1 =
135                            fused_shift_bg($shifts_buf.unwrap());
136                        let mut pass = compute_pass!(
137                            scope,
138                            encoder,
139                            concat!($label, "_fused_h")
140                        );
141                        pass.set_pipeline(
142                            &self.ntt_tile_fused_pointwise_pipeline,
143                        );
144                        pass.set_bind_group(0, &bg, &[]);
145                        pass.set_bind_group(1, &shifts_group1, &[]);
146                        pass.set_bind_group(2, &pointwise_fused_bg, &[]);
147                        pass.dispatch_workgroups(
148                            n.div_ceil(C::NTT_TILE_SIZE),
149                            1,
150                            1,
151                        );
152                    } else if $is_fused_shift {
153                        let shifts_group1 =
154                            fused_shift_bg($shifts_buf.unwrap());
155                        let mut pass = compute_pass!(
156                            scope,
157                            encoder,
158                            concat!($label, "_fused")
159                        );
160                        pass.set_pipeline(&self.ntt_fused_pipeline);
161                        pass.set_bind_group(0, &bg, &[]);
162                        pass.set_bind_group(1, &shifts_group1, &[]);
163                        pass.dispatch_workgroups(
164                            n.div_ceil(C::NTT_TILE_SIZE),
165                            1,
166                            1,
167                        );
168                    } else {
169                        let mut pass = compute_pass!(scope, encoder, $label);
170                        pass.set_pipeline(&self.ntt_pipeline);
171                        pass.set_bind_group(0, &bg, &[]);
172                        pass.dispatch_workgroups(
173                            n.div_ceil(C::NTT_TILE_SIZE),
174                            1,
175                            1,
176                        );
177                    }
178                } else {
179                    let mut log_n = 0u32;
180                    let mut m = n;
181                    while m > 1 {
182                        log_n += 1;
183                        m >>= 1;
184                    }
185
186                    let mut stage_params = [n, 0, log_n, 0];
187                    let params_buf = self.device.create_buffer_init(
188                        &wgpu::util::BufferInitDescriptor {
189                            label: Some("NTT Params Buffer"),
190                            contents: bytemuck::cast_slice(&stage_params),
191                            usage: wgpu::BufferUsages::UNIFORM
192                                | wgpu::BufferUsages::COPY_DST,
193                        },
194                    );
195
196                    let bg = self.device.create_bind_group(
197                        &wgpu::BindGroupDescriptor {
198                            label: Some("NTT Global BG"),
199                            layout: &self.ntt_params_bind_group_layout,
200                            entries: &[
201                                wgpu::BindGroupEntry {
202                                    binding: 0,
203                                    resource: $data_buf.as_entire_binding(),
204                                },
205                                wgpu::BindGroupEntry {
206                                    binding: 1,
207                                    resource: $tw_buf.as_entire_binding(),
208                                },
209                                wgpu::BindGroupEntry {
210                                    binding: 2,
211                                    resource: params_buf.as_entire_binding(),
212                                },
213                            ],
214                        },
215                    );
216
217                    if $is_h_fused_pointwise {
218                        let shifts_group1 =
219                            fused_shift_bg($shifts_buf.unwrap());
220                        let mut pass = compute_pass!(
221                            scope,
222                            encoder,
223                            concat!($label, "_bitreverse_fused_h")
224                        );
225                        pass.set_pipeline(
226                            &self.ntt_bitreverse_fused_pointwise_pipeline,
227                        );
228                        pass.set_bind_group(0, &bg, &[]);
229                        pass.set_bind_group(1, &shifts_group1, &[]);
230                        pass.set_bind_group(2, &pointwise_fused_bg, &[]);
231                        pass.dispatch_workgroups(
232                            n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
233                            1,
234                            1,
235                        );
236                    } else {
237                        let mut pass = compute_pass!(
238                            scope,
239                            encoder,
240                            concat!($label, "_bitreverse")
241                        );
242                        pass.set_pipeline(&self.ntt_bitreverse_pipeline);
243                        pass.set_bind_group(0, &bg, &[]);
244                        pass.dispatch_workgroups(
245                            n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
246                            1,
247                            1,
248                        );
249                    }
250
251                    let mut half_len = 1u32;
252                    if (log_n & 1) == 1 {
253                        stage_params[1] = half_len;
254                        let update_buf = self.device.create_buffer_init(
255                            &wgpu::util::BufferInitDescriptor {
256                                label: Some("NTT Params Update"),
257                                contents: bytemuck::cast_slice(&stage_params),
258                                usage: wgpu::BufferUsages::COPY_SRC,
259                            },
260                        );
261                        encoder.copy_buffer_to_buffer(
262                            &update_buf,
263                            0,
264                            &params_buf,
265                            0,
266                            16,
267                        );
268                        param_updates.push(update_buf);
269
270                        let mut pass = compute_pass!(
271                            scope,
272                            encoder,
273                            concat!($label, "_stage")
274                        );
275                        pass.set_pipeline(&self.ntt_global_stage_pipeline);
276                        pass.set_bind_group(0, &bg, &[]);
277                        pass.dispatch_workgroups(
278                            (n / 2).div_ceil(C::SCALAR_WORKGROUP_SIZE),
279                            1,
280                            1,
281                        );
282
283                        half_len = 2;
284                    }
285
286                    while half_len < n {
287                        stage_params[1] = half_len;
288                        let update_buf = self.device.create_buffer_init(
289                            &wgpu::util::BufferInitDescriptor {
290                                label: Some("NTT Params Update"),
291                                contents: bytemuck::cast_slice(&stage_params),
292                                usage: wgpu::BufferUsages::COPY_SRC,
293                            },
294                        );
295                        encoder.copy_buffer_to_buffer(
296                            &update_buf,
297                            0,
298                            &params_buf,
299                            0,
300                            16,
301                        );
302                        param_updates.push(update_buf);
303
304                        let mut pass = compute_pass!(
305                            scope,
306                            encoder,
307                            concat!($label, "_stage_radix4")
308                        );
309                        pass.set_pipeline(
310                            &self.ntt_global_stage_radix4_pipeline,
311                        );
312                        pass.set_bind_group(0, &bg, &[]);
313                        pass.dispatch_workgroups(
314                            (n / 4).div_ceil(C::SCALAR_WORKGROUP_SIZE),
315                            1,
316                            1,
317                        );
318
319                        half_len <<= 2;
320                    }
321
322                    if $is_fused_shift || $is_h_fused_pointwise {
323                        let shift_bg = self.device.create_bind_group(
324                            &wgpu::BindGroupDescriptor {
325                                label: Some("Coset Shift BG"),
326                                layout: &self.coset_shift_bind_group_layout,
327                                entries: &[
328                                    wgpu::BindGroupEntry {
329                                        binding: 0,
330                                        resource: $data_buf.as_entire_binding(),
331                                    },
332                                    wgpu::BindGroupEntry {
333                                        binding: 1,
334                                        resource: $shifts_buf
335                                            .unwrap()
336                                            .as_entire_binding(),
337                                    },
338                                ],
339                            },
340                        );
341                        let mut pass = compute_pass!(
342                            scope,
343                            encoder,
344                            concat!($label, "_shift")
345                        );
346                        pass.set_pipeline(&self.coset_shift_pipeline);
347                        pass.set_bind_group(0, &shift_bg, &[]);
348                        pass.dispatch_workgroups(
349                            n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
350                            1,
351                            1,
352                        );
353                    }
354
355                    param_updates.push(params_buf);
356                }
357            };
358        }
359
360        // 2. Fused iNTT + coset shift on A/B/C
361        encode_ntt!(
362            "intt_a",
363            bufs.a,
364            bufs.twiddles_inv,
365            true,
366            Some(bufs.shifts),
367            false
368        );
369        encode_ntt!(
370            "intt_b",
371            bufs.b,
372            bufs.twiddles_inv,
373            true,
374            Some(bufs.shifts),
375            false
376        );
377        encode_ntt!(
378            "intt_c",
379            bufs.c,
380            bufs.twiddles_inv,
381            true,
382            Some(bufs.shifts),
383            false
384        );
385
386        // 3. NTT on A/B/C
387        encode_ntt!(
388            "ntt_a",
389            bufs.a,
390            bufs.twiddles_fwd,
391            false,
392            None::<&wgpu::Buffer>,
393            false
394        );
395        encode_ntt!(
396            "ntt_b",
397            bufs.b,
398            bufs.twiddles_fwd,
399            false,
400            None::<&wgpu::Buffer>,
401            false
402        );
403        encode_ntt!(
404            "ntt_c",
405            bufs.c,
406            bufs.twiddles_fwd,
407            false,
408            None::<&wgpu::Buffer>,
409            false
410        );
411
412        // 4/5. Fused pointwise + iNTT(H) + inverse coset shift
413        encode_ntt!(
414            "intt_h",
415            bufs.h,
416            bufs.twiddles_inv,
417            false,
418            Some(bufs.inv_shifts),
419            true
420        );
421
422        // 6. From Montgomery on H
423        {
424            let bg = mont_bg(bufs.h);
425            let mut pass = compute_pass!(scope, encoder, "from_montgomery_h");
426            pass.set_pipeline(&self.from_montgomery_pipeline);
427            pass.set_bind_group(0, &bg, &[]);
428            pass.dispatch_workgroups(
429                n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
430                1,
431                1,
432            );
433        }
434
435        #[cfg(feature = "profiling")]
436        {
437            drop(scope);
438            profiler_guard.resolve_queries(&mut encoder);
439        }
440
441        self.queue.submit(Some(encoder.finish()));
442
443        // Ensure param buffers aren't dropped until the queue submission
444        // finishes
445        drop(param_updates);
446    }
447}