Skip to main content

taceo_groth16/
lib.rs

1use ark_ec::VariableBaseMSM;
2use ark_ec::pairing::Pairing;
3use ark_ec::{AffineRepr, CurveGroup};
4use ark_ff::{FftField, LegendreSymbol, PrimeField};
5use ark_poly::{EvaluationDomain, GeneralEvaluationDomain};
6use std::marker::PhantomData;
7use tracing::instrument;
8
9pub use ark_groth16::{Proof, ProvingKey, VerifyingKey};
10pub use ark_relations::r1cs::ConstraintMatrices;
11pub use reduction::{CircomReduction, LibSnarkReduction, R1CSToQAP};
12
13mod reduction;
14
/// Runs three tasks through nested `rayon::join` calls (so they may execute in parallel)
/// and returns their results as one flat `(first, second, third)` tuple, in argument order.
macro_rules! rayon_join3 {
    ($first: expr, $second: expr, $third: expr) => {{
        let (first_two, third) = rayon::join(|| rayon::join($first, $second), $third);
        let (first, second) = first_two;
        (first, second, third)
    }};
}
21
/// Runs five tasks through nested `rayon::join` calls (so they may execute in parallel)
/// and flattens the nested result pairs into a single `(r1, r2, r3, r4, r5)` tuple, in
/// argument order.
macro_rules! rayon_join5 {
    ($first: expr, $second: expr, $third: expr, $fourth: expr, $fifth: expr) => {{
        let (head, r5) = rayon::join(
            || rayon::join(|| rayon::join(|| rayon::join($first, $second), $third), $fourth),
            $fifth,
        );
        let (((r1, r2), r3), r4) = head;
        (r1, r2, r3, r4, r5)
    }};
}
31pub(crate) use rayon_join3;
32
33/// Computes the roots of unity over the provided prime field. This method
34/// is equivalent with [Circom's implementation](https://github.com/iden3/ffjavascript/blob/337b881579107ab74d5b2094dbe1910e33da4484/src/wasm_field1.js).
35///
36/// We calculate smallest quadratic non residue q (by checking q^((p-1)/2)=-1 mod p). We also calculate smallest t s.t. p-1=2^s*t, s is the two adicity.
37/// We use g=q^t (this is a 2^s-th root of unity) as (some kind of) generator and compute another domain by repeatedly squaring g, should get to 1 in the s+1-th step.
38/// Then if log2(\text{domain_size}) equals s we take q^2 as root of unity. Else we take the log2(\text{domain_size}) + 1-th element of the domain created above.
39fn roots_of_unity<F: PrimeField + FftField>() -> (F, Vec<F>) {
40    let mut roots = vec![F::zero(); F::TWO_ADICITY as usize + 1];
41    let mut q = F::one();
42    while q.legendre() != LegendreSymbol::QuadraticNonResidue {
43        q += F::one();
44    }
45    let z = q.pow(F::TRACE);
46    roots[0] = z;
47    for i in 1..roots.len() {
48        roots[i] = roots[i - 1].square();
49    }
50    roots.reverse();
51    (q, roots)
52}
53
/* Old way of computing the root of unity; it does not work for bls12_381:
let root_of_unity = {
    let domain_size_double = 2 * domain_size;
    let domain_double =
        D::new(domain_size_double).ok_or(SynthesisError::PolynomialDegreeTooLarge)?;
    domain_double.element(1)
};
The new one is computed the same way as in snarkjs (more precisely, in ffjavascript/src/wasm_field1.js).
Find the smallest quadratic non-residue q (by checking q^((p-1)/2) = -1 mod p). Write p - 1 = 2^s * t with t odd (t is F::TRACE, s is the two-adicity).
Use g = q^t, a primitive 2^s-th root of unity, as the generator. Build a second domain by repeatedly squaring g; it reaches 1 at the (s+1)-th step.
If log2(domain_size) equals s, take q^2 as the root of unity. Otherwise, with that list reversed so that index i holds a primitive 2^i-th root of unity, take the element at index log2(domain_size) + 1.
*/
66#[instrument(level = "debug", name = "root of unity", skip_all)]
67fn root_of_unity_for_groth16<F: PrimeField + FftField>(
68    pow: usize,
69    domain: &mut GeneralEvaluationDomain<F>,
70) -> F {
71    let (q, roots) = roots_of_unity::<F>();
72    match domain {
73        GeneralEvaluationDomain::Radix2(domain) => {
74            domain.group_gen = roots[pow];
75            domain.group_gen_inv = domain.group_gen.inverse().expect("can compute inverse");
76        }
77        GeneralEvaluationDomain::MixedRadix(domain) => {
78            domain.group_gen = roots[pow];
79            domain.group_gen_inv = domain.group_gen.inverse().expect("can compute inverse");
80        }
81    };
82    if u64::from(F::TWO_ADICITY) == domain.log_size_of_group() {
83        q.square()
84    } else {
85        roots[domain.log_size_of_group() as usize + 1]
86    }
87}
88
89/// A Groth16 proof protocol.
90///
91/// This struct should never be initialized, it only provides associated functions [`Groth16::prove`] and [`Groth16::verify`].
92pub struct Groth16<P: Pairing> {
93    phantom_data: PhantomData<P>,
94}
95
96impl<P: Pairing> Groth16<P> {
97    #[instrument(level = "debug", name = "Groth16 - Proof", skip_all)]
98    pub fn prove<R: R1CSToQAP>(
99        pkey: &ProvingKey<P>,
100        r: P::ScalarField,
101        s: P::ScalarField,
102        matrices: &ConstraintMatrices<P::ScalarField>,
103        witness: &[P::ScalarField],
104    ) -> eyre::Result<Proof<P>> {
105        let witness_len = witness.len();
106        let witness_should_len = matrices.num_witness_variables + matrices.num_instance_variables;
107        if witness_len != witness_should_len {
108            eyre::bail!("expected witness len {witness_should_len}, got len {witness_len}",)
109        }
110        let h = R::witness_map_from_matrices::<P>(matrices, witness)?;
111        let proof = Self::create_proof_with_assignment(
112            pkey,
113            r,
114            s,
115            h,
116            witness,
117            matrices.num_instance_variables,
118        )?;
119        Ok(proof)
120    }
121
122    fn calculate_coeff<C>(
123        initial: C,
124        query: &[C::Affine],
125        vk_param: C::Affine,
126        witness: &[P::ScalarField],
127    ) -> C
128    where
129        C: CurveGroup<ScalarField = P::ScalarField>,
130    {
131        let acc = C::msm_unchecked(&query[1..], witness);
132        let mut res = initial;
133        res += query[0].into_group();
134        res += vk_param.into_group();
135        res += acc;
136        res
137    }
138
139    #[instrument(level = "debug", name = "create proof with assignment", skip_all)]
140    fn create_proof_with_assignment(
141        pkey: &ProvingKey<P>,
142        r: P::ScalarField,
143        s: P::ScalarField,
144        h: Vec<P::ScalarField>,
145        witness: &[P::ScalarField],
146        num_inputs: usize,
147    ) -> eyre::Result<Proof<P>> {
148        let delta_g1 = pkey.delta_g1.into_group();
149        let alpha_g1 = pkey.vk.alpha_g1;
150        let beta_g1 = pkey.beta_g1;
151        let beta_g2 = pkey.vk.beta_g2;
152        let delta_g2 = pkey.vk.delta_g2.into_group();
153
154        let (r_g1, s_g1, s_g2, l_acc, h_acc) = rayon_join5!(
155            || {
156                let compute_a =
157                    tracing::debug_span!("compute A in create proof with assignment").entered();
158                // Compute A
159                let r_g1 = delta_g1 * r;
160                let r_g1 = Self::calculate_coeff(r_g1, &pkey.a_query, alpha_g1, &witness[1..]);
161                compute_a.exit();
162                r_g1
163            },
164            || {
165                let compute_b =
166                    tracing::debug_span!("compute B/G1 in create proof with assignment").entered();
167                // Compute B in G1
168                // In original implementation this is skipped if r==0, however r is shared in our case
169                let s_g1 = delta_g1 * s;
170                let s_g1 = Self::calculate_coeff(s_g1, &pkey.b_g1_query, beta_g1, &witness[1..]);
171                compute_b.exit();
172                s_g1
173            },
174            || {
175                let compute_b =
176                    tracing::debug_span!("compute B/G2 in create proof with assignment").entered();
177                // Compute B in G2
178                let s_g2 = delta_g2 * s;
179                let s_g2 = Self::calculate_coeff(s_g2, &pkey.b_g2_query, beta_g2, &witness[1..]);
180                compute_b.exit();
181                s_g2
182            },
183            || {
184                let msm_l_query = tracing::debug_span!("msm l_query").entered();
185                let result = P::G1::msm_unchecked(&pkey.l_query, &witness[num_inputs..]);
186                msm_l_query.exit();
187                result
188            },
189            || {
190                let msm_h_query = tracing::debug_span!("msm h_query").entered();
191                //perform the msm for h
192                let result = P::G1::msm_unchecked(&pkey.h_query, &h);
193                msm_h_query.exit();
194                result
195            }
196        );
197
198        let rs = r * s;
199        let r_s_delta_g1 = delta_g1 * rs;
200
201        let g_a = r_g1;
202        let g1_b = s_g1;
203
204        let r_g1_b = g1_b * r;
205
206        let s_g_a = g_a * s;
207
208        let mut g_c = s_g_a;
209        g_c += r_g1_b;
210        g_c -= r_s_delta_g1;
211        g_c += l_acc;
212
213        g_c += h_acc;
214
215        let g2_b = s_g2;
216
217        Ok(Proof {
218            a: g_a.into_affine(),
219            b: g2_b.into_affine(),
220            c: g_c.into_affine(),
221        })
222    }
223}
224
225impl<P: Pairing> Groth16<P> {
226    /// Verify a Groth16 proof.
227    /// This method is a wrapper arkworks Groth16 and does not use MPC.
228    pub fn verify(
229        vk: &VerifyingKey<P>,
230        proof: &Proof<P>,
231        public_inputs: &[P::ScalarField],
232    ) -> eyre::Result<()> {
233        let vk = ark_groth16::prepare_verifying_key(vk);
234        let proof_valid = ark_groth16::Groth16::<P>::verify_proof(&vk, proof, public_inputs)
235            .map_err(eyre::Report::from)?;
236        if proof_valid {
237            Ok(())
238        } else {
239            Err(eyre::eyre!("invalid proof"))
240        }
241    }
242}