1use wgpu::util::DeviceExt;
31
32use super::curve::GpuCurve;
33use super::{GpuContext, MsmBuffers, compute_pass};
34
35impl<C: GpuCurve> GpuContext<C> {
36 #[allow(clippy::too_many_arguments)]
37 pub fn execute_msm(
38 &self,
39 is_g2: bool,
40 bufs: &MsmBuffers<'_>,
41 num_active_buckets: u32,
42 num_dispatched: u32,
43 has_chunks: bool,
44 num_windows: u32,
45 skip_montgomery: bool,
46 ) {
47 let bases_buf = bufs.bases;
48 let base_indices_buf = bufs.base_indices;
49 let bucket_pointers_buf = bufs.bucket_pointers;
50 let bucket_sizes_buf = bufs.bucket_sizes;
51 let aggregated_buckets_buf = bufs.aggregated_buckets;
52 let bucket_values_buf = bufs.bucket_values;
53 let window_starts_buf = bufs.window_starts;
54 let window_counts_buf = bufs.window_counts;
55 let window_sums_buf = bufs.window_sums;
56
57 let point_gpu_bytes: u64 = if is_g2 {
58 C::G2_GPU_BYTES as u64
59 } else {
60 C::G1_GPU_BYTES as u64
61 };
62
63 let intermediate_buf = if has_chunks {
67 Some(self.device.create_buffer(&wgpu::BufferDescriptor {
68 label: Some("MSM Intermediate Sub-Buckets"),
69 size: num_dispatched as u64 * point_gpu_bytes,
70 usage: wgpu::BufferUsages::STORAGE,
71 mapped_at_creation: false,
72 }))
73 } else {
74 None
75 };
76
77 let agg_output_buf =
80 intermediate_buf.as_ref().unwrap_or(aggregated_buckets_buf);
81
82 let agg_bind_group =
83 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
84 label: Some("MSM Agg Bind Group"),
85 layout: &self.msm_agg_bind_group_layout,
86 entries: &[
87 wgpu::BindGroupEntry {
88 binding: 0,
89 resource: bases_buf.as_entire_binding(),
90 },
91 wgpu::BindGroupEntry {
92 binding: 1,
93 resource: base_indices_buf.as_entire_binding(),
94 },
95 wgpu::BindGroupEntry {
96 binding: 2,
97 resource: bucket_pointers_buf.as_entire_binding(),
98 },
99 wgpu::BindGroupEntry {
100 binding: 3,
101 resource: bucket_sizes_buf.as_entire_binding(),
102 },
103 wgpu::BindGroupEntry {
104 binding: 4,
105 resource: agg_output_buf.as_entire_binding(),
106 },
107 wgpu::BindGroupEntry {
108 binding: 5,
109 resource: bucket_values_buf.as_entire_binding(),
110 },
111 ],
112 });
113
114 let mut encoder = self.device.create_command_encoder(
115 &wgpu::CommandEncoderDescriptor {
116 label: Some("MSM Encoder"),
117 },
118 );
119
120 #[cfg(feature = "profiling")]
121 let mut profiler_guard = self.profiler.lock().unwrap();
122 #[cfg(feature = "profiling")]
123 let mut scope = profiler_guard
124 .scope(if is_g2 { "msm_g2" } else { "msm_g1" }, &mut encoder);
125
126 if !skip_montgomery {
131 let mont_bind_group =
132 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
133 label: Some("MSM Bases Mont Bind Group"),
134 layout: &self.montgomery_bind_group_layout,
135 entries: &[wgpu::BindGroupEntry {
136 binding: 0,
137 resource: bases_buf.as_entire_binding(),
138 }],
139 });
140 let point_size: u64 = if is_g2 {
141 C::G2_GPU_BYTES as u64
142 } else {
143 C::G1_GPU_BYTES as u64
144 };
145 let num_bases = (bases_buf.size() / point_size) as u32;
146 let mut cpass =
147 compute_pass!(scope, encoder, "to_montgomery_bases");
148 cpass.set_pipeline(if is_g2 {
149 &self.msm_to_mont_g2_pipeline
150 } else {
151 &self.msm_to_mont_g1_pipeline
152 });
153 cpass.set_bind_group(0, &mont_bind_group, &[]);
154 cpass.dispatch_workgroups(
155 num_bases.div_ceil(C::MSM_WORKGROUP_SIZE),
156 1,
157 1,
158 );
159 }
160
161 {
162 let mut cpass = compute_pass!(scope, encoder, "bucket_aggregation");
163 cpass.set_pipeline(if is_g2 {
164 &self.msm_agg_g2_pipeline
165 } else {
166 &self.msm_agg_g1_pipeline
167 });
168 cpass.set_bind_group(0, &agg_bind_group, &[]);
169 cpass.dispatch_workgroups(
170 num_dispatched.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
171 1,
172 1,
173 );
174 }
175
176 if has_chunks {
179 let reduce_starts_buf = bufs
180 .reduce_starts
181 .expect("reduce_starts required when has_chunks");
182 let reduce_counts_buf = bufs
183 .reduce_counts
184 .expect("reduce_counts required when has_chunks");
185 let reduce_bind_group =
186 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
187 label: Some("MSM Reduce Sub-Buckets BG"),
188 layout: &self.msm_reduce_bind_group_layout,
189 entries: &[
190 wgpu::BindGroupEntry {
191 binding: 0,
192 resource: agg_output_buf.as_entire_binding(),
193 },
194 wgpu::BindGroupEntry {
195 binding: 1,
196 resource: reduce_starts_buf.as_entire_binding(),
197 },
198 wgpu::BindGroupEntry {
199 binding: 2,
200 resource: reduce_counts_buf.as_entire_binding(),
201 },
202 wgpu::BindGroupEntry {
203 binding: 3,
204 resource: aggregated_buckets_buf
205 .as_entire_binding(),
206 },
207 ],
208 });
209 let mut cpass = compute_pass!(scope, encoder, "reduce_sub_buckets");
210 cpass.set_pipeline(if is_g2 {
211 &self.msm_reduce_g2_pipeline
212 } else {
213 &self.msm_reduce_g1_pipeline
214 });
215 cpass.set_bind_group(0, &reduce_bind_group, &[]);
216 cpass.dispatch_workgroups(
217 num_active_buckets.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
218 1,
219 1,
220 );
221 }
222
223 let weight_values_buf = if has_chunks {
227 bufs.orig_bucket_values
228 .expect("orig_bucket_values required when has_chunks")
229 } else {
230 bucket_values_buf
231 };
232 {
233 let weight_bind_group =
234 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
235 label: Some(if is_g2 {
236 "MSM Weight G2 BG"
237 } else {
238 "MSM Weight G1 BG"
239 }),
240 layout: if is_g2 {
241 &self.msm_weight_g2_bind_group_layout
242 } else {
243 &self.msm_weight_g1_bind_group_layout
244 },
245 entries: &[
246 wgpu::BindGroupEntry {
247 binding: 0,
248 resource: aggregated_buckets_buf
249 .as_entire_binding(),
250 },
251 wgpu::BindGroupEntry {
252 binding: 1,
253 resource: weight_values_buf.as_entire_binding(),
254 },
255 ],
256 });
257 let mut cpass = compute_pass!(scope, encoder, "bucket_weighting");
258 cpass.set_pipeline(if is_g2 {
259 &self.msm_weight_g2_pipeline
260 } else {
261 &self.msm_weight_g1_pipeline
262 });
263 cpass.set_bind_group(0, &weight_bind_group, &[]);
264 cpass.dispatch_workgroups(
265 num_active_buckets.div_ceil(C::MSM_WORKGROUP_SIZE).max(1),
266 1,
267 1,
268 );
269 }
270
271 {
280 let chunks_per_window = if is_g2 {
281 C::G2_SUBSUM_CHUNKS_PER_WINDOW
282 } else {
283 C::G1_SUBSUM_CHUNKS_PER_WINDOW
284 };
285 let subsum_window_starts = if has_chunks {
286 bufs.orig_window_starts
287 .expect("orig_window_starts required when has_chunks")
288 } else {
289 window_starts_buf
290 };
291 let subsum_window_counts = if has_chunks {
292 bufs.orig_window_counts
293 .expect("orig_window_counts required when has_chunks")
294 } else {
295 window_counts_buf
296 };
297
298 let partial_sums_buf =
299 self.device.create_buffer(&wgpu::BufferDescriptor {
300 label: Some("MSM Partial Sums"),
301 size: (num_windows * chunks_per_window) as u64
302 * point_gpu_bytes,
303 usage: wgpu::BufferUsages::STORAGE,
304 mapped_at_creation: false,
305 });
306 let subsum_params: [u32; 4] = [chunks_per_window, 0, 0, 0];
307 let subsum_params_buf = self.device.create_buffer_init(
308 &wgpu::util::BufferInitDescriptor {
309 label: Some("Subsum Params"),
310 contents: bytemuck::cast_slice(&subsum_params),
311 usage: wgpu::BufferUsages::UNIFORM,
312 },
313 );
314
315 let phase1_bind_group =
316 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
317 label: Some("MSM Subsum Phase1 BG"),
318 layout: &self.msm_subsum_phase1_bind_group_layout,
319 entries: &[
320 wgpu::BindGroupEntry {
321 binding: 0,
322 resource: aggregated_buckets_buf
323 .as_entire_binding(),
324 },
325 wgpu::BindGroupEntry {
326 binding: 1,
327 resource: subsum_window_starts.as_entire_binding(),
328 },
329 wgpu::BindGroupEntry {
330 binding: 2,
331 resource: subsum_window_counts.as_entire_binding(),
332 },
333 wgpu::BindGroupEntry {
334 binding: 3,
335 resource: partial_sums_buf.as_entire_binding(),
336 },
337 wgpu::BindGroupEntry {
338 binding: 4,
339 resource: subsum_params_buf.as_entire_binding(),
340 },
341 ],
342 });
343
344 let phase2_bind_group =
345 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
346 label: Some("MSM Subsum Phase2 BG"),
347 layout: &self.msm_subsum_phase2_bind_group_layout,
348 entries: &[
349 wgpu::BindGroupEntry {
350 binding: 0,
351 resource: partial_sums_buf.as_entire_binding(),
352 },
353 wgpu::BindGroupEntry {
354 binding: 1,
355 resource: window_sums_buf.as_entire_binding(),
356 },
357 wgpu::BindGroupEntry {
358 binding: 2,
359 resource: subsum_params_buf.as_entire_binding(),
360 },
361 ],
362 });
363
364 {
366 let mut cpass =
367 compute_pass!(scope, encoder, "tree_reduction_ph1");
368 cpass.set_pipeline(if is_g2 {
369 &self.msm_subsum_phase1_g2_pipeline
370 } else {
371 &self.msm_subsum_phase1_g1_pipeline
372 });
373 cpass.set_bind_group(0, &phase1_bind_group, &[]);
374 cpass.dispatch_workgroups(
375 num_windows * chunks_per_window,
376 1,
377 1,
378 );
379 }
380
381 {
383 let mut cpass =
384 compute_pass!(scope, encoder, "tree_reduction_ph2");
385 cpass.set_pipeline(if is_g2 {
386 &self.msm_subsum_phase2_g2_pipeline
387 } else {
388 &self.msm_subsum_phase2_g1_pipeline
389 });
390 cpass.set_bind_group(0, &phase2_bind_group, &[]);
391 cpass.dispatch_workgroups(num_windows, 1, 1);
392 }
393 }
394
395 #[cfg(feature = "profiling")]
396 {
397 drop(scope);
398 profiler_guard.resolve_queries(&mut encoder);
399 }
400
401 self.queue.submit(Some(encoder.finish()));
402 }
403
404 pub fn convert_to_montgomery(&self, buf: &wgpu::Buffer, is_g2: bool) {
407 let mont_bind_group =
408 self.device.create_bind_group(&wgpu::BindGroupDescriptor {
409 label: Some("Convert To Montgomery BG"),
410 layout: &self.montgomery_bind_group_layout,
411 entries: &[wgpu::BindGroupEntry {
412 binding: 0,
413 resource: buf.as_entire_binding(),
414 }],
415 });
416 let point_size: u64 = if is_g2 {
417 C::G2_GPU_BYTES as u64
418 } else {
419 C::G1_GPU_BYTES as u64
420 };
421 let num_bases = (buf.size() / point_size) as u32;
422 let mut encoder = self.device.create_command_encoder(
423 &wgpu::CommandEncoderDescriptor {
424 label: Some("Convert To Montgomery Encoder"),
425 },
426 );
427 {
428 let mut cpass =
429 encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
430 label: Some("to_montgomery"),
431 timestamp_writes: None,
432 });
433 cpass.set_pipeline(if is_g2 {
434 &self.msm_to_mont_g2_pipeline
435 } else {
436 &self.msm_to_mont_g1_pipeline
437 });
438 cpass.set_bind_group(0, &mont_bind_group, &[]);
439 cpass.dispatch_workgroups(
440 num_bases.div_ceil(C::MSM_WORKGROUP_SIZE),
441 1,
442 1,
443 );
444 }
445 self.queue.submit(Some(encoder.finish()));
446 }
447}