Skip to main content

webgpu_groth16/gpu/
ntt.rs

//! NTT and polynomial operation dispatchers.
//!
//! Contains GPU compute dispatch methods for:
//! - Montgomery domain conversion (to/from)
//! - Number Theoretic Transform (local tile and multi-stage global)
//! - Coset shift (multiply by powers of the multiplicative generator)
//! - Pointwise polynomial operations (H = (A·B − C) / Z)
8
9use wgpu::util::DeviceExt;
10
11use super::GpuContext;
12use super::curve::GpuCurve;
13
14impl<C: GpuCurve> GpuContext<C> {
15    pub fn execute_to_montgomery(
16        &self,
17        buffer: &wgpu::Buffer,
18        num_elements: u32,
19    ) {
20        let bind_group =
21            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
22                label: Some("To Montgomery Bind Group"),
23                layout: &self.montgomery_bind_group_layout,
24                entries: &[wgpu::BindGroupEntry {
25                    binding: 0,
26                    resource: buffer.as_entire_binding(),
27                }],
28            });
29        let mut encoder = self.device.create_command_encoder(
30            &wgpu::CommandEncoderDescriptor { label: None },
31        );
32        {
33            let mut cpass =
34                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
35                    label: None,
36                    timestamp_writes: None,
37                });
38            cpass.set_pipeline(&self.to_montgomery_pipeline);
39            cpass.set_bind_group(0, &bind_group, &[]);
40            cpass.dispatch_workgroups(
41                num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
42                1,
43                1,
44            );
45        }
46        self.queue.submit(Some(encoder.finish()));
47    }
48
49    pub fn execute_from_montgomery(
50        &self,
51        buffer: &wgpu::Buffer,
52        num_elements: u32,
53    ) {
54        let bind_group =
55            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
56                label: Some("From Montgomery Bind Group"),
57                layout: &self.montgomery_bind_group_layout,
58                entries: &[wgpu::BindGroupEntry {
59                    binding: 0,
60                    resource: buffer.as_entire_binding(),
61                }],
62            });
63        let mut encoder = self.device.create_command_encoder(
64            &wgpu::CommandEncoderDescriptor { label: None },
65        );
66        {
67            let mut cpass =
68                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
69                    label: None,
70                    timestamp_writes: None,
71                });
72            cpass.set_pipeline(&self.from_montgomery_pipeline);
73            cpass.set_bind_group(0, &bind_group, &[]);
74            cpass.dispatch_workgroups(
75                num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
76                1,
77                1,
78            );
79        }
80        self.queue.submit(Some(encoder.finish()));
81    }
82
83    pub fn execute_ntt(
84        &self,
85        data_buffer: &wgpu::Buffer,
86        twiddles_buffer: &wgpu::Buffer,
87        num_elements: u32,
88    ) {
89        if num_elements > C::NTT_TILE_SIZE {
90            self.execute_ntt_global(data_buffer, twiddles_buffer, num_elements);
91            return;
92        }
93
94        let bind_group =
95            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
96                label: Some("NTT Bind Group"),
97                layout: &self.ntt_bind_group_layout,
98                entries: &[
99                    wgpu::BindGroupEntry {
100                        binding: 0,
101                        resource: data_buffer.as_entire_binding(),
102                    },
103                    wgpu::BindGroupEntry {
104                        binding: 1,
105                        resource: twiddles_buffer.as_entire_binding(),
106                    },
107                ],
108            });
109        let mut encoder = self.device.create_command_encoder(
110            &wgpu::CommandEncoderDescriptor {
111                label: Some("NTT Encoder"),
112            },
113        );
114        {
115            let mut cpass =
116                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
117                    label: Some("NTT Pass"),
118                    timestamp_writes: None,
119                });
120            cpass.set_pipeline(&self.ntt_pipeline);
121            cpass.set_bind_group(0, &bind_group, &[]);
122            cpass.dispatch_workgroups(
123                num_elements.div_ceil(C::NTT_TILE_SIZE),
124                1,
125                1,
126            );
127        }
128        self.queue.submit(Some(encoder.finish()));
129    }
130
131    /// Multi-stage global NTT for sizes > NTT_TILE_SIZE (512).
132    ///
133    /// Algorithm:
134    /// 1. Bit-reversal permutation (in-place)
135    /// 2. Iterative butterfly stages: for each `half_len` in 1, 2, 4, ..., n/2,
136    ///    dispatches workgroups that combine pairs of elements using twiddle
137    ///    factors
138    ///
139    /// Each stage updates a uniform buffer with `[n, half_len, log_n, 0]` so
140    /// the shader knows the current butterfly geometry.
141    pub fn execute_ntt_global(
142        &self,
143        data_buffer: &wgpu::Buffer,
144        twiddles_buffer: &wgpu::Buffer,
145        num_elements: u32,
146    ) {
147        let mut log_n = 0u32;
148        let mut m = num_elements;
149        while m > 1 {
150            log_n += 1;
151            m >>= 1;
152        }
153
154        let mut encoder = self.device.create_command_encoder(
155            &wgpu::CommandEncoderDescriptor {
156                label: Some("NTT Global Encoder"),
157            },
158        );
159
160        let mut stage_params = [0u32; 4];
161        stage_params[0] = num_elements;
162        stage_params[2] = log_n;
163        let params_buf =
164            self.device
165                .create_buffer_init(&wgpu::util::BufferInitDescriptor {
166                    label: Some("NTT Params Buffer"),
167                    contents: bytemuck::cast_slice(&stage_params),
168                    usage: wgpu::BufferUsages::UNIFORM
169                        | wgpu::BufferUsages::COPY_DST,
170                });
171
172        let make_bind_group = |params_buf: &wgpu::Buffer| {
173            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
174                label: Some("NTT Global Bind Group"),
175                layout: &self.ntt_params_bind_group_layout,
176                entries: &[
177                    wgpu::BindGroupEntry {
178                        binding: 0,
179                        resource: data_buffer.as_entire_binding(),
180                    },
181                    wgpu::BindGroupEntry {
182                        binding: 1,
183                        resource: twiddles_buffer.as_entire_binding(),
184                    },
185                    wgpu::BindGroupEntry {
186                        binding: 2,
187                        resource: params_buf.as_entire_binding(),
188                    },
189                ],
190            })
191        };
192
193        // Bit-reversal pass
194        {
195            let bg = make_bind_group(&params_buf);
196            let mut pass =
197                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
198                    label: Some("NTT BitReverse Pass"),
199                    timestamp_writes: None,
200                });
201            pass.set_pipeline(&self.ntt_bitreverse_pipeline);
202            pass.set_bind_group(0, &bg, &[]);
203            pass.dispatch_workgroups(
204                num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
205                1,
206                1,
207            );
208        }
209
210        // Butterfly stages (radix-4 with optional first radix-2 when needed)
211        let mut half_len = 1u32;
212        let mut param_updates: Vec<wgpu::Buffer> = Vec::new();
213
214        if (log_n & 1) == 1 {
215            stage_params[1] = half_len;
216            let update_buf = self.device.create_buffer_init(
217                &wgpu::util::BufferInitDescriptor {
218                    label: Some("NTT Params Update"),
219                    contents: bytemuck::cast_slice(&stage_params),
220                    usage: wgpu::BufferUsages::COPY_SRC,
221                },
222            );
223            encoder.copy_buffer_to_buffer(&update_buf, 0, &params_buf, 0, 16);
224            param_updates.push(update_buf);
225
226            let bg = make_bind_group(&params_buf);
227            let mut pass =
228                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
229                    label: Some("NTT Global Stage R2 Pass"),
230                    timestamp_writes: None,
231                });
232            pass.set_pipeline(&self.ntt_global_stage_pipeline);
233            pass.set_bind_group(0, &bg, &[]);
234            pass.dispatch_workgroups(
235                (num_elements / 2).div_ceil(C::SCALAR_WORKGROUP_SIZE),
236                1,
237                1,
238            );
239
240            half_len = 2;
241        }
242
243        while half_len < num_elements {
244            stage_params[1] = half_len;
245            let update_buf = self.device.create_buffer_init(
246                &wgpu::util::BufferInitDescriptor {
247                    label: Some("NTT Params Update"),
248                    contents: bytemuck::cast_slice(&stage_params),
249                    usage: wgpu::BufferUsages::COPY_SRC,
250                },
251            );
252            encoder.copy_buffer_to_buffer(&update_buf, 0, &params_buf, 0, 16);
253            param_updates.push(update_buf);
254
255            let bg = make_bind_group(&params_buf);
256            let mut pass =
257                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
258                    label: Some("NTT Global Stage R4 Pass"),
259                    timestamp_writes: None,
260                });
261            pass.set_pipeline(&self.ntt_global_stage_radix4_pipeline);
262            pass.set_bind_group(0, &bg, &[]);
263            pass.dispatch_workgroups(
264                (num_elements / 4).div_ceil(C::SCALAR_WORKGROUP_SIZE),
265                1,
266                1,
267            );
268
269            half_len <<= 2;
270        }
271
272        self.queue.submit(Some(encoder.finish()));
273    }
274
275    pub fn execute_coset_shift(
276        &self,
277        data_buffer: &wgpu::Buffer,
278        shifts_buffer: &wgpu::Buffer,
279        num_elements: u32,
280    ) {
281        let bind_group =
282            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
283                label: Some("Coset Shift Bind Group"),
284                layout: &self.coset_shift_bind_group_layout,
285                entries: &[
286                    wgpu::BindGroupEntry {
287                        binding: 0,
288                        resource: data_buffer.as_entire_binding(),
289                    },
290                    wgpu::BindGroupEntry {
291                        binding: 1,
292                        resource: shifts_buffer.as_entire_binding(),
293                    },
294                ],
295            });
296        let mut encoder = self.device.create_command_encoder(
297            &wgpu::CommandEncoderDescriptor { label: None },
298        );
299        {
300            let mut cpass =
301                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
302                    label: None,
303                    timestamp_writes: None,
304                });
305            cpass.set_pipeline(&self.coset_shift_pipeline);
306            cpass.set_bind_group(0, &bind_group, &[]);
307            cpass.dispatch_workgroups(
308                num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
309                1,
310                1,
311            );
312        }
313        self.queue.submit(Some(encoder.finish()));
314    }
315
316    pub fn execute_pointwise_poly(
317        &self,
318        a_buf: &wgpu::Buffer,
319        b_buf: &wgpu::Buffer,
320        c_buf: &wgpu::Buffer,
321        h_buf: &wgpu::Buffer,
322        z_invs_buf: &wgpu::Buffer,
323        num_elements: u32,
324    ) {
325        let bind_group =
326            self.device.create_bind_group(&wgpu::BindGroupDescriptor {
327                label: Some("Pointwise Poly Bind Group"),
328                layout: &self.pointwise_poly_bind_group_layout,
329                entries: &[
330                    wgpu::BindGroupEntry {
331                        binding: 0,
332                        resource: a_buf.as_entire_binding(),
333                    },
334                    wgpu::BindGroupEntry {
335                        binding: 1,
336                        resource: b_buf.as_entire_binding(),
337                    },
338                    wgpu::BindGroupEntry {
339                        binding: 2,
340                        resource: c_buf.as_entire_binding(),
341                    },
342                    wgpu::BindGroupEntry {
343                        binding: 3,
344                        resource: h_buf.as_entire_binding(),
345                    },
346                    wgpu::BindGroupEntry {
347                        binding: 4,
348                        resource: z_invs_buf.as_entire_binding(),
349                    },
350                ],
351            });
352        let mut encoder = self.device.create_command_encoder(
353            &wgpu::CommandEncoderDescriptor { label: None },
354        );
355        {
356            let mut cpass =
357                encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
358                    label: None,
359                    timestamp_writes: None,
360                });
361            cpass.set_pipeline(&self.pointwise_poly_pipeline);
362            cpass.set_bind_group(0, &bind_group, &[]);
363            cpass.dispatch_workgroups(
364                num_elements.div_ceil(C::SCALAR_WORKGROUP_SIZE),
365                1,
366                1,
367            );
368        }
369        self.queue.submit(Some(encoder.finish()));
370    }
371}