snarkvm_algorithms/msm/variable_base/batched.rs

use snarkvm_curves::{AffineCurve, ProjectiveCurve};
use snarkvm_fields::{Field, One, PrimeField, Zero};
use snarkvm_utilities::{BigInteger, BitIteratorBE, cfg_into_iter};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};
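/// The bucket that a scalar's current window digit selects, together with the index of
/// that scalar. Ordering and equality look only at `bucket_index`, so sorting a slice of
/// `BucketPosition`s groups together all entries that land in the same bucket.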
#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.bucket_index.partial_cmp(&other.bucket_index)
    }
}
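/// Returns how many pending affine additions are grouped into one batch, so that a single
/// field inversion is amortized across the whole batch. The constants are empirically
/// tuned; the smaller value is used on x86_64 for smaller MSMs, which keeps the batch's
/// working set friendlier to the CPU cache.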
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        300
    } else {
        3000
    }
}
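/// For every `(idx, idy)` pair in `index`, adds the point at `idy` into the point at
/// `idx`, in place within `bases`. Two passes over the pairs share one field inversion.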
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }
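    // Invert the accumulated product once; the reverse pass below uses it to finish
    // every pending affine addition without any further inversions.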
    inversion_tmp = inversion_tmp.inverse().unwrap();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}
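/// For every `(idx, idy)` pair in `index`, pushes `bases[idx] + bases[idy]` onto
/// `addition_result`; a sentinel of `idy == !0u32` pushes `bases[idx]` through unchanged.
/// One field inversion is shared across the batch, with `scratch_space` holding the
/// second operand of each pending addition between the two passes.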
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }
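    // One shared inversion, then a reverse pass completes each pending addition;
    // `None` entries in the scratch space mark bases that were copied through as-is.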
    inversion_tmp = inversion_tmp.inverse().unwrap();

    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}
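/// Accumulates all bases that fall into the same bucket, returning one point per bucket
/// (zero for buckets that received no bases). Entries whose `bucket_index` is at least
/// `num_buckets` are ignored. Buckets are reduced by repeatedly pairing up their points
/// and adding the pairs in batches.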
#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    let batch_size = batch_size(bases.len());

    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);
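    // First pass over the original bases: for each bucket, pair up its entries and queue
    // the pairwise additions; the results (and any unpaired leftovers) are written into
    // `new_bases`, and `bucket_positions` is compacted to describe the new layout.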
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            if number_of_bases_in_batch >= batch_size / 2 {
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;
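    // Keep halving the number of points per bucket, this time pairing entries of
    // `new_bases` in place, until every bucket is down to a single point.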
    while !all_ones {
        all_ones = true;
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }
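    // Each surviving entry now holds a bucket's fully accumulated point; scatter them
    // into the result vector (buckets that received nothing stay zero).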
    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}
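/// Computes one `c`-bit window of the MSM starting at bit `w_start`: scalars are bucketed
/// by their window digit, the buckets are accumulated with `batch_add`, and the buckets
/// are then combined with a running sum so that bucket `i` is weighted by `i + 1`.
/// Returns the window's sum together with the window size.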
#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    let window_size = if (w_start % c) != 0 { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            scalar.divn(w_start as u32);

            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();
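    // A window digit of 0 wraps to `bucket_index == u32::MAX` (via the `scalar - 1` cast
    // above), which `batch_add` discards since it is not below `num_buckets`; zero digits
    // therefore contribute nothing.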
    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}
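/// Computes the variable-base multi-scalar multiplication `sum_i scalars[i] * bases[i]`.
/// Very small instances use a shared double-and-add loop; larger instances use a
/// Pippenger-style windowed algorithm built on batched bucket additions.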
pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();
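        // Plain double-and-add across all scalars at once, scanning bits from the most
        // significant end; doubling only begins once the first set bit has been seen.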
        let mut encountered_one = false;
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        let (lowest, window_sums) = window_sums.split_first().unwrap();
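        // Combine the windows from most significant to least: add each window's sum and
        // shift the running total up by `window_size` doublings, then add the lowest
        // window's sum at the end.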
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}