snarkvm_algorithms/msm/variable_base/mod.rs

// Copyright (c) 2019-2025 Provable Inc.
// This file is part of the snarkVM library.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at:

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16pub mod batched;
17pub mod standard;
18
19#[cfg(target_arch = "x86_64")]
20pub mod prefetch;
21
22use snarkvm_curves::{bls12_377::G1Affine, traits::AffineCurve};
23use snarkvm_fields::PrimeField;
24
25use core::any::TypeId;
26
27pub struct VariableBase;
28
29impl VariableBase {
30    pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
31        // For BLS12-377, we perform variable base MSM using a batched addition
32        // technique.
33        if TypeId::of::<G>() == TypeId::of::<G1Affine>() {
34            #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
35            // TODO SNP: where to set the threshold
36            if scalars.len() > 1024 {
37                let result = snarkvm_algorithms_cuda::msm::<G, G::Projective, <G::ScalarField as PrimeField>::BigInteger>(
38                    bases, scalars,
39                );
40                if let Ok(result) = result {
41                    return result;
42                }
43            }
44            batched::msm(bases, scalars)
45        }
46        // For all other curves, we perform variable base MSM using Pippenger's algorithm.
47        else {
48            standard::msm(bases, scalars)
49        }
50    }
51
52    #[cfg(test)]
53    fn msm_naive<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
54        use itertools::Itertools;
55        use snarkvm_utilities::BitIteratorBE;
56
57        bases.iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
58    }
59
60    #[cfg(test)]
61    fn msm_naive_parallel<G: AffineCurve>(
62        bases: &[G],
63        scalars: &[<G::ScalarField as PrimeField>::BigInteger],
64    ) -> G::Projective {
65        use rayon::prelude::*;
66        use snarkvm_utilities::BitIteratorBE;
67
68        bases.par_iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
69    }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_curves::{
        ProjectiveCurve,
        bls12_377::{Fr, G1Affine, G2Affine},
    };
    use snarkvm_fields::PrimeField;
    use snarkvm_utilities::rand::TestRng;

    /// Samples `size` random affine bases and `size` random scalars (as big integers).
    fn create_scalar_bases<G: AffineCurve<ScalarField = F>, F: PrimeField>(
        rng: &mut TestRng,
        size: usize,
    ) -> (Vec<G>, Vec<F::BigInteger>) {
        let bases = (0..size).map(|_| G::rand(rng)).collect::<Vec<_>>();
        let scalars = (0..size).map(|_| F::rand(rng).to_bigint()).collect::<Vec<_>>();
        (bases, scalars)
    }

    /// Cross-checks every MSM implementation, including the public dispatcher, against the naive
    /// reference on BLS12-377 G1.
    #[test]
    fn test_msm() {
        for msm_size in [1, 5, 10, 50, 100, 500, 1000] {
            let mut rng = TestRng::default();
            let (bases, scalars) = create_scalar_bases::<G1Affine, Fr>(&mut rng, msm_size);

            let naive_a = VariableBase::msm_naive(bases.as_slice(), scalars.as_slice()).to_affine();
            let naive_b = VariableBase::msm_naive_parallel(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, naive_b, "MSM size: {msm_size}");

            let candidate = standard::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");

            let candidate = batched::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");

            // The public entry point must agree as well. For G1Affine at these sizes (all <= 1024)
            // it dispatches to `batched::msm`, so this checks the dispatch itself.
            let candidate = VariableBase::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");
        }
    }

    /// Exercises the non-G1 branch of `VariableBase::msm` (which calls `standard::msm`) by
    /// running it on BLS12-377 G2 points.
    #[test]
    fn test_msm_dispatch_non_g1() {
        let mut rng = TestRng::default();
        for msm_size in [1, 10, 100] {
            let (bases, scalars) = create_scalar_bases::<G2Affine, Fr>(&mut rng, msm_size);

            let naive = VariableBase::msm_naive(bases.as_slice(), scalars.as_slice()).to_affine();
            let candidate = VariableBase::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive, candidate, "MSM size: {msm_size}");
        }
    }

    /// Compares the dispatcher (which may use the CUDA backend above 1024 scalars) against
    /// `standard::msm` for sizes 2^2 through 2^16.
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    #[test]
    fn test_msm_cuda() {
        let mut rng = TestRng::default();
        for i in 2..17 {
            let (bases, scalars) = create_scalar_bases::<G1Affine, Fr>(&mut rng, 1 << i);
            let rust = standard::msm(bases.as_slice(), scalars.as_slice());
            let cuda = VariableBase::msm::<G1Affine>(bases.as_slice(), scalars.as_slice());
            assert_eq!(rust.to_affine(), cuda.to_affine());
        }
    }
}