snarkvm_algorithms/msm/variable_base/batched.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use snarkvm_curves::{AffineCurve, ProjectiveCurve};
use snarkvm_fields::{Field, One, PrimeField, Zero};
use snarkvm_utilities::{BigInteger, BitIteratorBE, cfg_into_iter};

#[cfg(not(feature = "serial"))]
use rayon::prelude::*;

#[cfg(target_arch = "x86_64")]
use crate::{prefetch_slice, prefetch_slice_write};

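/// Pairs a scalar (and its base) with the bucket it contributes to: `bucket_index` is the
/// target bucket, and `scalar_index` is the index of the scalar/base in the input slices.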
#[derive(Copy, Clone, Debug)]
pub struct BucketPosition {
    pub bucket_index: u32,
    pub scalar_index: u32,
}

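// Note that equality and ordering compare only `bucket_index`, so sorting a slice of
// `BucketPosition`s groups together all entries that belong to the same bucket.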
impl Eq for BucketPosition {}

impl PartialEq for BucketPosition {
    fn eq(&self, other: &Self) -> bool {
        self.bucket_index == other.bucket_index
    }
}

impl Ord for BucketPosition {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.bucket_index.cmp(&other.bucket_index)
    }
}

impl PartialOrd for BucketPosition {
    #[allow(clippy::non_canonical_partial_ord_impl)]
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        self.bucket_index.partial_cmp(&other.bucket_index)
    }
}

/// Returns a batch size large enough to amortize the cost of an inversion,
/// while limiting strain on the CPU cache.
#[inline]
const fn batch_size(msm_size: usize) -> usize {
    // These values were determined empirically using performance benchmarks for
    // BLS12-377 on Intel, AMD, and M1 machines, by taking the L1 and L2 cache
    // sizes and dividing them by the size of a group element (i.e. 96 bytes).
    //
    // As the algorithm itself requires caching additional values beyond the group
    // elements, the ideal batch size is smaller than this quotient, to accommodate
    // those values. In general, it was found that undershooting is better than
    // overshooting this heuristic.
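    //
    // For reference: 32 KiB / 96 B is roughly 341 group elements (undershot to 300
    // below), while 1 MiB / 96 B is roughly 10,922 (undershot much further, to 3000,
    // to leave room for the auxiliary values mentioned above).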
    if cfg!(target_arch = "x86_64") && msm_size < 500_000 {
        // Assumes an L1 cache size of 32KiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L1 cache sizes
        // are.
        300
    } else {
        // Assumes an L2 cache size of 1MiB. Note that larger cache sizes
        // are not negatively impacted by this value, however smaller L2 cache sizes
        // are.
        3000
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method sets
/// `bases[j] = bases[j] + bases[k]`. The state of `bases[k]` becomes
/// unspecified.
#[inline]
fn batch_add_in_place_same_slice<G: AffineCurve>(bases: &mut [G], index: &[(u32, u32)]) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
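    // (This is Montgomery's batch-inversion trick: the first loop accumulates the product
    // of the per-addition denominators in `inversion_tmp`, a single field inversion is
    // performed, and the second loop walks the instructions in reverse to apply each
    // individual inverse.)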
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], &mut y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], &mut x[*idy as usize])
        };
        G::batch_add_loop_1(a, b, &half, &mut inversion_tmp);
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter().rev();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    for (idx, idy) in index.iter().rev() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice!(G, bases, bases, prefetch_iter);

        let (a, b) = if idx < idy {
            let (x, y) = bases.split_at_mut(*idy as usize);
            (&mut x[*idx as usize], y[0])
        } else {
            let (x, y) = bases.split_at_mut(*idx as usize);
            (&mut y[0], x[*idy as usize])
        };
        G::batch_add_loop_2(a, b, &mut inversion_tmp);
    }
}

/// If `(j, k)` is the `i`-th entry in `index`, then this method performs one of
/// two actions:
/// * `addition_result[i] = bases[j] + bases[k]`
/// * `addition_result[i] = bases[j]` (when `k` is the `!0u32` sentinel, i.e. a plain copy)
///
/// It uses `scratch_space` to store intermediate values, and clears it after
/// use.
#[inline]
fn batch_add_write<G: AffineCurve>(
    bases: &[G],
    index: &[(u32, u32)],
    addition_result: &mut Vec<G>,
    scratch_space: &mut Vec<Option<G>>,
) {
    let mut inversion_tmp = G::BaseField::one();
    let half = G::BaseField::half();

    #[cfg(target_arch = "x86_64")]
    let mut prefetch_iter = index.iter();
    #[cfg(target_arch = "x86_64")]
    prefetch_iter.next();

    // We run two loops over the data separated by an inversion
    for (idx, idy) in index.iter() {
        #[cfg(target_arch = "x86_64")]
        prefetch_slice_write!(G, bases, bases, prefetch_iter);

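        // A second index of `!0u32` is the sentinel for a no-op/copy instruction:
        // the base at `idx` is pushed into `addition_result` unchanged.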
        if *idy == !0u32 {
            addition_result.push(bases[*idx as usize]);
            scratch_space.push(None);
        } else {
            let (mut a, mut b) = (bases[*idx as usize], bases[*idy as usize]);
            G::batch_add_loop_1(&mut a, &mut b, &half, &mut inversion_tmp);
            addition_result.push(a);
            scratch_space.push(Some(b));
        }
    }

    inversion_tmp = inversion_tmp.inverse().unwrap(); // this is always in Fp*

    for (a, op_b) in addition_result.iter_mut().rev().zip(scratch_space.iter().rev()) {
        if let Some(b) = op_b {
            G::batch_add_loop_2(a, *b, &mut inversion_tmp);
        }
    }
    scratch_space.clear();
}

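/// Sums the bases assigned to each bucket using batched affine additions.
///
/// Each entry of `bucket_positions` pairs a base (via `scalar_index`) with its target
/// bucket (via `bucket_index`); entries whose `bucket_index` is `>= num_buckets` are
/// ignored. The positions are first sorted by bucket, and each bucket's bases are then
/// reduced pairwise (an addition tree) over repeated passes until a single point per
/// bucket remains. Returns one point per bucket (zero for empty buckets).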
#[inline]
pub(super) fn batch_add<G: AffineCurve>(
    num_buckets: usize,
    bases: &[G],
    bucket_positions: &mut [BucketPosition],
) -> Vec<G> {
    assert!(bases.len() >= bucket_positions.len());
    assert!(!bases.is_empty());

    // Fetch the ideal batch size for the number of bases.
    let batch_size = batch_size(bases.len());

    // Sort the buckets by their bucket index (not scalar index).
    bucket_positions.sort_unstable();

    let mut num_scalars = bucket_positions.len();
    let mut all_ones = true;
    let mut new_scalar_length = 0;
    let mut global_counter = 0;
    let mut local_counter = 1;
    let mut number_of_bases_in_batch = 0;

    let mut instr = Vec::<(u32, u32)>::with_capacity(batch_size);
    let mut new_bases = Vec::with_capacity(bases.len());
    let mut scratch_space = Vec::with_capacity(batch_size / 2);

    // In the first pass, write the results of the first round of each bucket's
    // addition tree into the vector `new_bases`.
    while global_counter < num_scalars {
        let current_bucket = bucket_positions[global_counter].bucket_index;
        while global_counter + 1 < num_scalars && bucket_positions[global_counter + 1].bucket_index == current_bucket {
            global_counter += 1;
            local_counter += 1;
        }
        if current_bucket >= num_buckets as u32 {
            local_counter = 1;
        } else if local_counter > 1 {
            // If this bucket has more than 2 entries, one pairing pass cannot reduce it
            // to a single base, so further in-place passes will be needed.
            if local_counter > 2 {
                all_ones = false;
            }
            let is_odd = local_counter % 2 == 1;
            let half = local_counter / 2;
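            // Pair up the `local_counter` entries of this bucket: instruction `i` adds
            // the bases at positions `2 * i` and `2 * i + 1`; an odd leftover is carried
            // forward as a copy.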
            for i in 0..half {
                instr.push((
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                    bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                ));
                bucket_positions[new_scalar_length + i] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + i) as u32 };
            }
            if is_odd {
                instr.push((bucket_positions[global_counter].scalar_index, !0u32));
                bucket_positions[new_scalar_length + half] =
                    BucketPosition { bucket_index: current_bucket, scalar_index: (new_scalar_length + half) as u32 };
            }
            // Reset the local_counter and update state
            new_scalar_length += half + (local_counter % 2);
            number_of_bases_in_batch += half;
            local_counter = 1;

            // When the number of bases in a batch crosses the threshold, perform a batch
            // addition.
            if number_of_bases_in_batch >= batch_size / 2 {
                // We need instructions for copying data in the case of noops.
                // We encode noops/copies as !0u32
                batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);

                instr.clear();
                number_of_bases_in_batch = 0;
            }
        } else {
            instr.push((bucket_positions[global_counter].scalar_index, !0u32));
            bucket_positions[new_scalar_length] =
                BucketPosition { bucket_index: current_bucket, scalar_index: new_scalar_length as u32 };
            new_scalar_length += 1;
        }
        global_counter += 1;
    }
    if !instr.is_empty() {
        batch_add_write(bases, &instr, &mut new_bases, &mut scratch_space);
        instr.clear();
    }
    global_counter = 0;
    number_of_bases_in_batch = 0;
    local_counter = 1;
    num_scalars = new_scalar_length;
    new_scalar_length = 0;

    // Next, perform all the updates in place.
    while !all_ones {
        all_ones = true;
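        // Each in-place pass roughly halves the length of every remaining bucket run;
        // the outer loop exits once every bucket has been reduced to a single point
        // (i.e. `all_ones` stays true for a full pass).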
        while global_counter < num_scalars {
            let current_bucket = bucket_positions[global_counter].bucket_index;
            while global_counter + 1 < num_scalars
                && bucket_positions[global_counter + 1].bucket_index == current_bucket
            {
                global_counter += 1;
                local_counter += 1;
            }
            if current_bucket >= num_buckets as u32 {
                local_counter = 1;
            } else if local_counter > 1 {
                // If this bucket still has more than 2 entries, this pass cannot reduce it
                // to a single base, so another pass is required.
                if local_counter != 2 {
                    all_ones = false;
                }
                let is_odd = local_counter % 2 == 1;
                let half = local_counter / 2;
                for i in 0..half {
                    instr.push((
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i].scalar_index,
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i + 1].scalar_index,
                    ));
                    bucket_positions[new_scalar_length + i] =
                        bucket_positions[global_counter - (local_counter - 1) + 2 * i];
                }
                if is_odd {
                    bucket_positions[new_scalar_length + half] = bucket_positions[global_counter];
                }
                // Reset the local_counter and update state
                new_scalar_length += half + (local_counter % 2);
                number_of_bases_in_batch += half;
                local_counter = 1;

                if number_of_bases_in_batch >= batch_size / 2 {
                    batch_add_in_place_same_slice(&mut new_bases, &instr);
                    instr.clear();
                    number_of_bases_in_batch = 0;
                }
            } else {
                bucket_positions[new_scalar_length] = bucket_positions[global_counter];
                new_scalar_length += 1;
            }
            global_counter += 1;
        }
        // If there are any remaining unprocessed instructions, proceed to perform batch
        // addition.
        if !instr.is_empty() {
            batch_add_in_place_same_slice(&mut new_bases, &instr);
            instr.clear();
        }
        global_counter = 0;
        number_of_bases_in_batch = 0;
        local_counter = 1;
        num_scalars = new_scalar_length;
        new_scalar_length = 0;
    }

    let mut res = vec![Zero::zero(); num_buckets];
    for bucket_position in bucket_positions.iter().take(num_scalars) {
        res[bucket_position.bucket_index as usize] = new_bases[bucket_position.scalar_index as usize];
    }
    res
}

#[inline]
fn batched_window<G: AffineCurve>(
    bases: &[G],
    scalars: &[<G::ScalarField as PrimeField>::BigInteger],
    w_start: usize,
    c: usize,
) -> (G::Projective, usize) {
    // We don't need the "zero" bucket, so we only have 2^c - 1 buckets
    let window_size = if (w_start % c) != 0 { w_start % c } else { c };
    let num_buckets = (1 << window_size) - 1;

    let mut bucket_positions: Vec<_> = scalars
        .iter()
        .enumerate()
        .map(|(scalar_index, &scalar)| {
            let mut scalar = scalar;

            // We right-shift by w_start, thus getting rid of the lower bits.
            scalar.divn(w_start as u32);

            // We reduce the remaining bits modulo 2^c to obtain this window's digit.
            let scalar = (scalar.as_ref()[0] % (1 << c)) as i32;

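            // A digit of 0 maps to bucket index `!0u32` (u32::MAX), which `batch_add`
            // treats as out of range and ignores.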
            BucketPosition { bucket_index: (scalar - 1) as u32, scalar_index: scalar_index as u32 }
        })
        .collect();

    let buckets = batch_add(num_buckets, bases, &mut bucket_positions);

    let mut res = G::Projective::zero();
    let mut running_sum = G::Projective::zero();
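    // Bucket aggregation: iterating from the highest bucket down while keeping a running
    // (suffix) sum means `buckets[j]` ends up added `j + 1` times, i.e. the result is
    // the sum over j of (j + 1) * buckets[j], using only about 2 * num_buckets additions.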
    for b in buckets.into_iter().rev() {
        running_sum.add_assign_mixed(&b);
        res += &running_sum;
    }

    (res, window_size)
}

pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
    if bases.len() < 15 {
        let num_bits = G::ScalarField::size_in_bits();
        let bigint_size = <G::ScalarField as PrimeField>::BigInteger::NUM_LIMBS * 64;
        let mut bits =
            scalars.iter().map(|s| BitIteratorBE::new(s.as_ref()).skip(bigint_size - num_bits)).collect::<Vec<_>>();
        let mut sum = G::Projective::zero();

        let mut encountered_one = false;
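        // Small MSMs: a joint double-and-add over all scalars, scanning bits MSB-first.
        // Doubling is deferred until the first set bit has been seen, so leading zero
        // bits incur no doublings.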
        for _ in 0..num_bits {
            if encountered_one {
                sum.double_in_place();
            }
            for (bits, base) in bits.iter_mut().zip(bases) {
                if let Some(true) = bits.next() {
                    sum.add_assign_mixed(base);
                    encountered_one = true;
                }
            }
        }
        debug_assert!(bits.iter_mut().all(|b| b.next().is_none()));
        sum
    } else {
        // Determine the window size `c` in bits (chosen empirically).
        let c = match scalars.len() < 32 {
            true => 1,
            false => crate::msm::ln_without_floats(scalars.len()) + 2,
        };

        let num_bits = <G::ScalarField as PrimeField>::size_in_bits();

        // Each window is of size `c`.
        // We divide up the bits 0..num_bits into windows of size `c`, and
        // in parallel process each such window.
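        // This yields ceil(num_bits / c) windows, each of which is reduced independently
        // by `batched_window`.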
        let window_sums: Vec<_> =
            cfg_into_iter!(0..num_bits).step_by(c).map(|w_start| batched_window(bases, scalars, w_start, c)).collect();

        // We store the sum for the lowest window.
        let (lowest, window_sums) = window_sums.split_first().unwrap();

        // We're traversing windows from high to low.
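        // Horner-style recombination: add each window's sum, then double `window_size`
        // times to shift the running total up before the next (lower) window is added;
        // the lowest window's sum is added at the end without any doubling.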
        window_sums.iter().rev().fold(G::Projective::zero(), |mut total, (sum_i, window_size)| {
            total += sum_i;
            for _ in 0..*window_size {
                total.double_in_place();
            }
            total
        }) + lowest.0
    }
}