snarkvm_algorithms/msm/variable_base/
mod.rs1pub mod batched;
17pub mod standard;
18
19#[cfg(target_arch = "x86_64")]
20pub mod prefetch;
21
22use snarkvm_curves::{bls12_377::G1Affine, traits::AffineCurve};
23use snarkvm_fields::PrimeField;
24
25use core::any::TypeId;
26
27pub struct VariableBase;
28
29impl VariableBase {
30 pub fn msm<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
31 if TypeId::of::<G>() == TypeId::of::<G1Affine>() {
34 #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
35 if scalars.len() > 1024 {
37 let result = snarkvm_algorithms_cuda::msm::<G, G::Projective, <G::ScalarField as PrimeField>::BigInteger>(
38 bases, scalars,
39 );
40 if let Ok(result) = result {
41 return result;
42 }
43 }
44 batched::msm(bases, scalars)
45 }
46 else {
48 standard::msm(bases, scalars)
49 }
50 }
51
52 #[cfg(test)]
53 fn msm_naive<G: AffineCurve>(bases: &[G], scalars: &[<G::ScalarField as PrimeField>::BigInteger]) -> G::Projective {
54 use itertools::Itertools;
55 use snarkvm_utilities::BitIteratorBE;
56
57 bases.iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
58 }
59
60 #[cfg(test)]
61 fn msm_naive_parallel<G: AffineCurve>(
62 bases: &[G],
63 scalars: &[<G::ScalarField as PrimeField>::BigInteger],
64 ) -> G::Projective {
65 use rayon::prelude::*;
66 use snarkvm_utilities::BitIteratorBE;
67
68 bases.par_iter().zip_eq(scalars).map(|(base, scalar)| base.mul_bits(BitIteratorBE::new(*scalar))).sum()
69 }
70}
71
#[cfg(test)]
mod tests {
    use super::*;
    use snarkvm_curves::{
        bls12_377::{Fr, G1Affine},
        ProjectiveCurve,
    };
    use snarkvm_fields::PrimeField;
    use snarkvm_utilities::rand::TestRng;

    /// Samples `size` random bases and `size` random scalars (converted to big integers) from `rng`.
    fn create_scalar_bases<G: AffineCurve<ScalarField = F>, F: PrimeField>(
        rng: &mut TestRng,
        size: usize,
    ) -> (Vec<G>, Vec<F::BigInteger>) {
        let bases = (0..size).map(|_| G::rand(rng)).collect::<Vec<_>>();
        let scalars = (0..size).map(|_| F::rand(rng).to_bigint()).collect::<Vec<_>>();
        (bases, scalars)
    }

    /// Checks the naive references, the `standard` and `batched` implementations, and the public
    /// `VariableBase::msm` dispatcher against each other on BLS12-377 G1.
    #[test]
    fn test_msm() {
        // One RNG for the whole test, matching `test_msm_cuda`.
        let mut rng = TestRng::default();
        for msm_size in [1, 5, 10, 50, 100, 500, 1000] {
            let (bases, scalars) = create_scalar_bases::<G1Affine, Fr>(&mut rng, msm_size);

            let naive_a = VariableBase::msm_naive(bases.as_slice(), scalars.as_slice()).to_affine();
            let naive_b = VariableBase::msm_naive_parallel(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, naive_b, "MSM size: {msm_size}");

            let candidate = standard::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");

            let candidate = batched::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");

            // The public entry point must agree too, not just the backends it dispatches to.
            let candidate = VariableBase::msm(bases.as_slice(), scalars.as_slice()).to_affine();
            assert_eq!(naive_a, candidate, "MSM size: {msm_size}");
        }
    }

    /// Compares `VariableBase::msm` (which may use CUDA) against `standard::msm` for sizes 4..=65536.
    ///
    /// NOTE(review): `VariableBase::msm` only tries CUDA for more than 1024 scalars and silently
    /// falls back to the CPU path on a CUDA error, so this test can pass without the GPU path running.
    #[cfg(all(feature = "cuda", target_arch = "x86_64"))]
    #[test]
    fn test_msm_cuda() {
        let mut rng = TestRng::default();
        for i in 2..17 {
            let msm_size = 1 << i;
            let (bases, scalars) = create_scalar_bases::<G1Affine, Fr>(&mut rng, msm_size);
            let rust = standard::msm(bases.as_slice(), scalars.as_slice());
            let cuda = VariableBase::msm::<G1Affine>(bases.as_slice(), scalars.as_slice());
            assert_eq!(rust.to_affine(), cuda.to_affine(), "MSM size: {msm_size}");
        }
    }
}