snarkvm_algorithms/msm/variable_base/batched.rs

use snarkvm_curves::{AffineCurve, ProjectiveCurve};
use snarkvm_fields::{Field, One, PrimeField, Zero};
use snarkvm_utilities::{BigInteger, BitIteratorBE, cfg_into_iter};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

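/// The position of a point during bucketed accumulation: the bucket it is assigned to,
/// and the index of its (scalar, base) pair. Ordering and equality compare only the
/// bucket index, so sorting groups all points destined for the same bucket together.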
#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.bucket_index.partial_cmp(&other.bucket_index)
    }
}

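/// Returns the number of additions to accumulate before flushing a batch (and paying for
/// the single shared field inversion). A smaller batch is used on x86_64 for MSMs with
/// fewer than 500,000 terms, presumably to stay friendlier to the CPU cache.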
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        300
    } else {
        3000
    }
}

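/// Performs the additions described by `index` within a single slice: for each pair
/// `(idx, idy)`, `bases[idx]` is replaced by `bases[idx] + bases[idy]`. The whole batch
/// shares one field inversion, split across a forward pass (`batch_add_loop_1`) and a
/// reverse pass (`batch_add_loop_2`).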
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // First pass: accumulate the product of the addition denominators into `inversion_tmp`.
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    // A single inversion, amortised over every addition in the batch.
    inversion_tmp = inversion_tmp.inverse().unwrap();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // Second pass (in reverse): apply the shared inverse to finish each addition in place.
    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}

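/// Performs the additions described by `index`, appending each result to `addition_result`
/// rather than writing back into `bases`. A pair `(idx, !0u32)` marks a lone point that is
/// copied through unchanged. `scratch_space` temporarily holds the second operand of every
/// pending addition so that a single inversion can be shared across the batch.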
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // First pass: lone points (idy == !0u32) are copied through; pairs run loop 1 and
    // stash the second operand in `scratch_space` for the second pass.
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    // A single inversion, amortised over every addition in the batch.
    inversion_tmp = inversion_tmp.inverse().unwrap();

    // Second pass (in reverse): apply the shared inverse to finish each pending addition.
    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

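/// Accumulates `bases` into `num_buckets` bucket sums according to `bucket_positions`.
/// Points sharing a bucket are paired up and added with batched affine additions, round
/// after round, until each bucket is reduced to a single point; positions whose bucket
/// index is out of range (e.g. `!0u32` for a zero digit) are dropped. Returns one affine
/// point per bucket, with untouched buckets left at zero.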
#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    let batch_size = batch_size(bases.len());

    // Sort by bucket index so that points belonging to the same bucket are adjacent.
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // First round: pair up the points within each bucket and write each pairwise sum out
    // to `new_bases` (via `batch_add_write`), compacting `bucket_positions` as we go.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            // A bucket with more than two points will need further rounds.
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // Flush once enough additions have accumulated to amortise the inversion.
            if number_of_bases_in_batch >= batch_size / 2 {
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Subsequent rounds: keep pairing points within each bucket, now adding in place
    // within `new_bases`, until every bucket has been reduced to a single point.
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    // Gather the surviving point of each bucket; untouched buckets remain zero.
    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

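/// Computes the bucketed sum for one `c`-bit window starting at bit `w_start`: each
/// scalar's window digit selects a bucket (a zero digit is skipped), `batch_add`
/// accumulates the buckets, and a reverse running sum weights bucket `i` by `i + 1`.
/// Returns the window's partial sum along with the window size in bits.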
#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // The "zero" bucket is not needed, so there are only 2^c - 1 buckets.
    let window_size = if (w_start % c) != 0 { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // Shift off the bits below this window, then take the low `c` bits as the digit.
            scalar.divn(w_start as u32);
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

            // A zero digit wraps to bucket index `!0u32`, which `batch_add` discards.
            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    // Combine the buckets with a running sum so that bucket `i` is weighted by `i + 1`.
    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}

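/// Computes the multi-scalar multiplication `sum_i scalars[i] * bases[i]`. Inputs with
/// fewer than 15 bases use plain double-and-add; larger inputs use the windowed, bucketed
/// method above, processing windows in parallel unless the `serial` feature is enabled.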
pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        // For very small inputs, plain double-and-add is cheaper than bucketing.
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Window size `c`, chosen as a function of the input size.
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is processed independently (in parallel unless the `serial` feature is enabled).
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // Keep the lowest window aside; it needs no trailing doublings.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // Traverse the remaining windows from high to low, doubling `window_size` times between windows.
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}