use wgpu::util::DeviceExt;

use super::curve::GpuCurve;
use super::{GpuContext, HPolyBuffers, compute_pass};
22impl<C: GpuCurve> GpuContext<C> {
23 pub fn execute_h_pipeline(&self, bufs: &HPolyBuffers<'_>, n: u32) {
24 let mut encoder = self.device.create_command_encoder(
25 &wgpu::CommandEncoderDescriptor {
26 label: Some("H Pipeline Encoder"),
27 },
28 );
29
30 let mont_bg = |buf: &wgpu::Buffer| {
31 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
32 label: Some("Montgomery BG"),
33 layout: &self.montgomery_bind_group_layout,
34 entries: &[wgpu::BindGroupEntry {
35 binding: 0,
36 resource: buf.as_entire_binding(),
37 }],
38 })
39 };
40
41 let ntt_bg = |data: &wgpu::Buffer, tw: &wgpu::Buffer| {
42 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
43 label: Some("NTT BG"),
44 layout: &self.ntt_bind_group_layout,
45 entries: &[
46 wgpu::BindGroupEntry {
47 binding: 0,
48 resource: data.as_entire_binding(),
49 },
50 wgpu::BindGroupEntry {
51 binding: 1,
52 resource: tw.as_entire_binding(),
53 },
54 ],
55 })
56 };
57
58 let fused_shift_bg = |shifts: &wgpu::Buffer| {
59 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
60 label: Some("NTT Fused Shift BG"),
61 layout: &self.ntt_fused_shift_bgl,
62 entries: &[wgpu::BindGroupEntry {
63 binding: 0,
64 resource: shifts.as_entire_binding(),
65 }],
66 })
67 };
68
69 let pointwise_fused_bg =
70 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
71 label: Some("Pointwise Fused BG"),
72 layout: &self.pointwise_fused_bind_group_layout,
73 entries: &[
74 wgpu::BindGroupEntry {
75 binding: 0,
76 resource: bufs.a.as_entire_binding(),
77 },
78 wgpu::BindGroupEntry {
79 binding: 1,
80 resource: bufs.b.as_entire_binding(),
81 },
82 wgpu::BindGroupEntry {
83 binding: 2,
84 resource: bufs.c.as_entire_binding(),
85 },
86 wgpu::BindGroupEntry {
87 binding: 3,
88 resource: bufs.z_invs.as_entire_binding(),
89 },
90 ],
91 });
92
93 #[cfg(feature = "profiling")]
94 let mut profiler_guard = self.profiler.lock().unwrap();
95 #[cfg(feature = "profiling")]
96 let mut scope = profiler_guard.scope("h_pipeline", &mut encoder);
97
98 let mut param_updates = Vec::new();
100
101 for bg in [
103 mont_bg(bufs.a),
104 mont_bg(bufs.b),
105 mont_bg(bufs.c),
106 mont_bg(bufs.twiddles_inv),
107 mont_bg(bufs.twiddles_fwd),
108 mont_bg(bufs.shifts),
109 mont_bg(bufs.inv_shifts),
110 mont_bg(bufs.z_invs),
111 ] {
112 let mut pass = compute_pass!(scope, encoder, "to_montgomery");
113 pass.set_pipeline(&self.to_montgomery_pipeline);
114 pass.set_bind_group(0, &bg, &[]);
115 pass.dispatch_workgroups(
116 n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
117 1,
118 1,
119 );
120 }
121
122 macro_rules! encode_ntt {
123 (
124 $label:expr,
125 $data_buf:expr,
126 $tw_buf:expr,
127 $is_fused_shift:expr,
128 $shifts_buf:expr,
129 $is_h_fused_pointwise:expr
130 ) => {
131 if n <= C::NTT_TILE_SIZE {
132 let bg = ntt_bg($data_buf, $tw_buf);
133 if $is_h_fused_pointwise {
134 let shifts_group1 =
135 fused_shift_bg($shifts_buf.unwrap());
136 let mut pass = compute_pass!(
137 scope,
138 encoder,
139 concat!($label, "_fused_h")
140 );
141 pass.set_pipeline(
142 &self.ntt_tile_fused_pointwise_pipeline,
143 );
144 pass.set_bind_group(0, &bg, &[]);
145 pass.set_bind_group(1, &shifts_group1, &[]);
146 pass.set_bind_group(2, &pointwise_fused_bg, &[]);
147 pass.dispatch_workgroups(
148 n.div_ceil(C::NTT_TILE_SIZE),
149 1,
150 1,
151 );
152 } else if $is_fused_shift {
153 let shifts_group1 =
154 fused_shift_bg($shifts_buf.unwrap());
155 let mut pass = compute_pass!(
156 scope,
157 encoder,
158 concat!($label, "_fused")
159 );
160 pass.set_pipeline(&self.ntt_fused_pipeline);
161 pass.set_bind_group(0, &bg, &[]);
162 pass.set_bind_group(1, &shifts_group1, &[]);
163 pass.dispatch_workgroups(
164 n.div_ceil(C::NTT_TILE_SIZE),
165 1,
166 1,
167 );
168 } else {
169 let mut pass = compute_pass!(scope, encoder, $label);
170 pass.set_pipeline(&self.ntt_pipeline);
171 pass.set_bind_group(0, &bg, &[]);
172 pass.dispatch_workgroups(
173 n.div_ceil(C::NTT_TILE_SIZE),
174 1,
175 1,
176 );
177 }
178 } else {
179 let mut log_n = 0u32;
180 let mut m = n;
181 while m > 1 {
182 log_n += 1;
183 m >>= 1;
184 }
185
186 let mut stage_params = [n, 0, log_n, 0];
187 let params_buf = self.device.create_buffer_init(
188 &wgpu::util::BufferInitDescriptor {
189 label: Some("NTT Params Buffer"),
190 contents: bytemuck::cast_slice(&stage_params),
191 usage: wgpu::BufferUsages::UNIFORM
192 | wgpu::BufferUsages::COPY_DST,
193 },
194 );
195
196 let bg = self.device.create_bind_group(
197 &wgpu::BindGroupDescriptor {
198 label: Some("NTT Global BG"),
199 layout: &self.ntt_params_bind_group_layout,
200 entries: &[
201 wgpu::BindGroupEntry {
202 binding: 0,
203 resource: $data_buf.as_entire_binding(),
204 },
205 wgpu::BindGroupEntry {
206 binding: 1,
207 resource: $tw_buf.as_entire_binding(),
208 },
209 wgpu::BindGroupEntry {
210 binding: 2,
211 resource: params_buf.as_entire_binding(),
212 },
213 ],
214 },
215 );
216
217 if $is_h_fused_pointwise {
218 let shifts_group1 =
219 fused_shift_bg($shifts_buf.unwrap());
220 let mut pass = compute_pass!(
221 scope,
222 encoder,
223 concat!($label, "_bitreverse_fused_h")
224 );
225 pass.set_pipeline(
226 &self.ntt_bitreverse_fused_pointwise_pipeline,
227 );
228 pass.set_bind_group(0, &bg, &[]);
229 pass.set_bind_group(1, &shifts_group1, &[]);
230 pass.set_bind_group(2, &pointwise_fused_bg, &[]);
231 pass.dispatch_workgroups(
232 n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
233 1,
234 1,
235 );
236 } else {
237 let mut pass = compute_pass!(
238 scope,
239 encoder,
240 concat!($label, "_bitreverse")
241 );
242 pass.set_pipeline(&self.ntt_bitreverse_pipeline);
243 pass.set_bind_group(0, &bg, &[]);
244 pass.dispatch_workgroups(
245 n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
246 1,
247 1,
248 );
249 }
250
251 let mut half_len = 1u32;
252 if (log_n & 1) == 1 {
253 stage_params[1] = half_len;
254 let update_buf = self.device.create_buffer_init(
255 &wgpu::util::BufferInitDescriptor {
256 label: Some("NTT Params Update"),
257 contents: bytemuck::cast_slice(&stage_params),
258 usage: wgpu::BufferUsages::COPY_SRC,
259 },
260 );
261 encoder.copy_buffer_to_buffer(
262 &update_buf,
263 0,
264 ¶ms_buf,
265 0,
266 16,
267 );
268 param_updates.push(update_buf);
269
270 let mut pass = compute_pass!(
271 scope,
272 encoder,
273 concat!($label, "_stage")
274 );
275 pass.set_pipeline(&self.ntt_global_stage_pipeline);
276 pass.set_bind_group(0, &bg, &[]);
277 pass.dispatch_workgroups(
278 (n / 2).div_ceil(C::SCALAR_WORKGROUP_SIZE),
279 1,
280 1,
281 );
282
283 half_len = 2;
284 }
285
286 while half_len < n {
287 stage_params[1] = half_len;
288 let update_buf = self.device.create_buffer_init(
289 &wgpu::util::BufferInitDescriptor {
290 label: Some("NTT Params Update"),
291 contents: bytemuck::cast_slice(&stage_params),
292 usage: wgpu::BufferUsages::COPY_SRC,
293 },
294 );
295 encoder.copy_buffer_to_buffer(
296 &update_buf,
297 0,
298 ¶ms_buf,
299 0,
300 16,
301 );
302 param_updates.push(update_buf);
303
304 let mut pass = compute_pass!(
305 scope,
306 encoder,
307 concat!($label, "_stage_radix4")
308 );
309 pass.set_pipeline(
310 &self.ntt_global_stage_radix4_pipeline,
311 );
312 pass.set_bind_group(0, &bg, &[]);
313 pass.dispatch_workgroups(
314 (n / 4).div_ceil(C::SCALAR_WORKGROUP_SIZE),
315 1,
316 1,
317 );
318
319 half_len <<= 2;
320 }
321
322 if $is_fused_shift || $is_h_fused_pointwise {
323 let shift_bg = self.device.create_bind_group(
324 &wgpu::BindGroupDescriptor {
325 label: Some("Coset Shift BG"),
326 layout: &self.coset_shift_bind_group_layout,
327 entries: &[
328 wgpu::BindGroupEntry {
329 binding: 0,
330 resource: $data_buf.as_entire_binding(),
331 },
332 wgpu::BindGroupEntry {
333 binding: 1,
334 resource: $shifts_buf
335 .unwrap()
336 .as_entire_binding(),
337 },
338 ],
339 },
340 );
341 let mut pass = compute_pass!(
342 scope,
343 encoder,
344 concat!($label, "_shift")
345 );
346 pass.set_pipeline(&self.coset_shift_pipeline);
347 pass.set_bind_group(0, &shift_bg, &[]);
348 pass.dispatch_workgroups(
349 n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
350 1,
351 1,
352 );
353 }
354
355 param_updates.push(params_buf);
356 }
357 };
358 }
359
360 encode_ntt!(
362 "intt_a",
363 bufs.a,
364 bufs.twiddles_inv,
365 true,
366 Some(bufs.shifts),
367 false
368 );
369 encode_ntt!(
370 "intt_b",
371 bufs.b,
372 bufs.twiddles_inv,
373 true,
374 Some(bufs.shifts),
375 false
376 );
377 encode_ntt!(
378 "intt_c",
379 bufs.c,
380 bufs.twiddles_inv,
381 true,
382 Some(bufs.shifts),
383 false
384 );
385
386 encode_ntt!(
388 "ntt_a",
389 bufs.a,
390 bufs.twiddles_fwd,
391 false,
392 None::<&wgpu::Buffer>,
393 false
394 );
395 encode_ntt!(
396 "ntt_b",
397 bufs.b,
398 bufs.twiddles_fwd,
399 false,
400 None::<&wgpu::Buffer>,
401 false
402 );
403 encode_ntt!(
404 "ntt_c",
405 bufs.c,
406 bufs.twiddles_fwd,
407 false,
408 None::<&wgpu::Buffer>,
409 false
410 );
411
412 encode_ntt!(
414 "intt_h",
415 bufs.h,
416 bufs.twiddles_inv,
417 false,
418 Some(bufs.inv_shifts),
419 true
420 );
421
422 {
424 let bg = mont_bg(bufs.h);
425 let mut pass = compute_pass!(scope, encoder, "from_montgomery_h");
426 pass.set_pipeline(&self.from_montgomery_pipeline);
427 pass.set_bind_group(0, &bg, &[]);
428 pass.dispatch_workgroups(
429 n.div_ceil(C::SCALAR_WORKGROUP_SIZE),
430 1,
431 1,
432 );
433 }
434
435 #[cfg(feature = "profiling")]
436 {
437 drop(scope);
438 profiler_guard.resolve_queries(&mut encoder);
439 }
440
441 self.queue.submit(Some(encoder.finish()));
442
443 drop(param_updates);
446 }
447}