sp1_recursion_program/
stark.rs

1use p3_air::Air;
2use p3_commit::TwoAdicMultiplicativeCoset;
3use p3_field::{AbstractField, TwoAdicField};
4
5use sp1_recursion_compiler::{
6    ir::{Array, Builder, Config, Ext, ExtConst, SymbolicExt, SymbolicVar, Usize, Var},
7    prelude::Felt,
8};
9
10use sp1_recursion_core::runtime::DIGEST_SIZE;
11use sp1_stark::{
12    air::MachineAir, Com, GenericVerifierConstraintFolder, ShardProof, StarkGenericConfig,
13    StarkMachine, StarkVerifyingKey,
14};
15
16use crate::{
17    challenger::{CanObserveVariable, DuplexChallengerVariable, FeltChallenger},
18    commit::{PcsVariable, PolynomialSpaceVariable},
19    fri::{
20        types::{TwoAdicPcsMatsVariable, TwoAdicPcsRoundVariable},
21        TwoAdicFriPcsVariable, TwoAdicMultiplicativeCosetVariable,
22    },
23    types::{ShardCommitmentVariable, ShardProofVariable, VerifyingKeyVariable},
24};
25
26use crate::types::QuotientData;
27
/// Sentinel stored in a proof's `sorted_idxs` for a machine chip that has no entry in the
/// shard's opened values, i.e. the chip is not present in that shard.
///
/// The recursive verifier skips constraint checking for such chips and rejects the sentinel for
/// chips that have preprocessed columns.
pub const EMPTY: usize = 0x1111_1111;
29
30pub trait StarkRecursiveVerifier<C: Config> {
31    fn verify_shard(
32        &self,
33        builder: &mut Builder<C>,
34        vk: &VerifyingKeyVariable<C>,
35        pcs: &TwoAdicFriPcsVariable<C>,
36        challenger: &mut DuplexChallengerVariable<C>,
37        proof: &ShardProofVariable<C>,
38        is_complete: impl Into<SymbolicVar<C::N>>,
39    );
40
41    fn verify_shards(
42        &self,
43        builder: &mut Builder<C>,
44        vk: &VerifyingKeyVariable<C>,
45        pcs: &TwoAdicFriPcsVariable<C>,
46        challenger: &mut DuplexChallengerVariable<C>,
47        proofs: &Array<C, ShardProofVariable<C>>,
48        is_complete: impl Into<SymbolicVar<C::N>> + Clone,
49    ) {
50        // Assert that the number of shards is not zero.
51        builder.assert_usize_ne(proofs.len(), 0);
52
53        // Verify each shard.
54        builder.range(0, proofs.len()).for_each(|i, builder| {
55            let proof = builder.get(proofs, i);
56            self.verify_shard(builder, vk, pcs, challenger, &proof, is_complete.clone());
57        });
58    }
59}
60
61#[derive(Debug, Clone, Copy)]
62pub struct StarkVerifier<C: Config, SC: StarkGenericConfig> {
63    _phantom: std::marker::PhantomData<(C, SC)>,
64}
65
66pub struct ShardProofHint<'a, SC: StarkGenericConfig, A> {
67    pub machine: &'a StarkMachine<SC, A>,
68    pub proof: &'a ShardProof<SC>,
69}
70
71impl<'a, SC: StarkGenericConfig, A: MachineAir<SC::Val>> ShardProofHint<'a, SC, A> {
72    pub const fn new(machine: &'a StarkMachine<SC, A>, proof: &'a ShardProof<SC>) -> Self {
73        Self { machine, proof }
74    }
75}
76
77pub struct VerifyingKeyHint<'a, SC: StarkGenericConfig, A> {
78    pub machine: &'a StarkMachine<SC, A>,
79    pub vk: &'a StarkVerifyingKey<SC>,
80}
81
82impl<'a, SC: StarkGenericConfig, A: MachineAir<SC::Val>> VerifyingKeyHint<'a, SC, A> {
83    pub const fn new(machine: &'a StarkMachine<SC, A>, vk: &'a StarkVerifyingKey<SC>) -> Self {
84        Self { machine, vk }
85    }
86}
87
88pub type RecursiveVerifierConstraintFolder<'a, C> = GenericVerifierConstraintFolder<
89    'a,
90    <C as Config>::F,
91    <C as Config>::EF,
92    Felt<<C as Config>::F>,
93    Ext<<C as Config>::F, <C as Config>::EF>,
94    SymbolicExt<<C as Config>::F, <C as Config>::EF>,
95>;
96
97impl<C: Config, SC: StarkGenericConfig> StarkVerifier<C, SC>
98where
99    C::F: TwoAdicField,
100    SC: StarkGenericConfig<
101        Val = C::F,
102        Challenge = C::EF,
103        Domain = TwoAdicMultiplicativeCoset<C::F>,
104    >,
105{
106    pub fn verify_shard<A>(
107        builder: &mut Builder<C>,
108        vk: &VerifyingKeyVariable<C>,
109        pcs: &TwoAdicFriPcsVariable<C>,
110        machine: &StarkMachine<SC, A>,
111        challenger: &mut DuplexChallengerVariable<C>,
112        proof: &ShardProofVariable<C>,
113        check_cumulative_sum: bool,
114    ) where
115        A: MachineAir<C::F> + for<'a> Air<RecursiveVerifierConstraintFolder<'a, C>>,
116        C::F: TwoAdicField,
117        C::EF: TwoAdicField,
118        Com<SC>: Into<[SC::Val; DIGEST_SIZE]>,
119    {
120        builder.cycle_tracker("stage-c-verify-shard-setup");
121        let ShardProofVariable { commitment, opened_values, opening_proof, .. } = proof;
122
123        let ShardCommitmentVariable { main_commit, permutation_commit, quotient_commit } =
124            commitment;
125
126        let permutation_challenges =
127            (0..2).map(|_| challenger.sample_ext(builder)).collect::<Vec<_>>();
128
129        challenger.observe(builder, permutation_commit.clone());
130
131        let alpha = challenger.sample_ext(builder);
132
133        challenger.observe(builder, quotient_commit.clone());
134
135        let zeta = challenger.sample_ext(builder);
136
137        let num_shard_chips = opened_values.chips.len();
138        let mut trace_domains =
139            builder.dyn_array::<TwoAdicMultiplicativeCosetVariable<_>>(num_shard_chips);
140        let mut quotient_domains =
141            builder.dyn_array::<TwoAdicMultiplicativeCosetVariable<_>>(num_shard_chips);
142
143        let num_preprocessed_chips = machine.preprocessed_chip_ids().len();
144
145        let mut prep_mats: Array<_, TwoAdicPcsMatsVariable<_>> =
146            builder.dyn_array(num_preprocessed_chips);
147        let mut main_mats: Array<_, TwoAdicPcsMatsVariable<_>> = builder.dyn_array(num_shard_chips);
148        let mut perm_mats: Array<_, TwoAdicPcsMatsVariable<_>> = builder.dyn_array(num_shard_chips);
149
150        let num_quotient_mats: Var<_> = builder.eval(C::N::zero());
151        builder.range(0, num_shard_chips).for_each(|i, builder| {
152            let num_quotient_chunks = builder.get(&proof.quotient_data, i).quotient_size;
153            builder.assign(num_quotient_mats, num_quotient_mats + num_quotient_chunks);
154        });
155
156        let mut quotient_mats: Array<_, TwoAdicPcsMatsVariable<_>> =
157            builder.dyn_array(num_quotient_mats);
158
159        let mut qc_points = builder.dyn_array::<Ext<_, _>>(1);
160        builder.set_value(&mut qc_points, 0, zeta);
161
162        // Iterate through machine.chips filtered for preprocessed chips.
163        for (preprocessed_id, chip_id) in machine.preprocessed_chip_ids().into_iter().enumerate() {
164            // Get index within sorted preprocessed chips.
165            let preprocessed_sorted_id = builder.get(&vk.preprocessed_sorted_idxs, preprocessed_id);
166            // Get domain from witnessed domains. Array is ordered by machine.chips ordering.
167            let domain = builder.get(&vk.prep_domains, preprocessed_id);
168
169            // Get index within all sorted chips.
170            let chip_sorted_id = builder.get(&proof.sorted_idxs, chip_id);
171            // Get opening from proof.
172            let opening = builder.get(&opened_values.chips, chip_sorted_id);
173
174            let mut trace_points = builder.dyn_array::<Ext<_, _>>(2);
175            let zeta_next = domain.next_point(builder, zeta);
176
177            builder.set_value(&mut trace_points, 0, zeta);
178            builder.set_value(&mut trace_points, 1, zeta_next);
179
180            let mut prep_values = builder.dyn_array::<Array<C, _>>(2);
181            builder.set_value(&mut prep_values, 0, opening.preprocessed.local);
182            builder.set_value(&mut prep_values, 1, opening.preprocessed.next);
183            let main_mat = TwoAdicPcsMatsVariable::<C> {
184                domain: domain.clone(),
185                values: prep_values,
186                points: trace_points.clone(),
187            };
188            builder.set_value(&mut prep_mats, preprocessed_sorted_id, main_mat);
189        }
190
191        let qc_index: Var<_> = builder.eval(C::N::zero());
192        builder.range(0, num_shard_chips).for_each(|i, builder| {
193            let opening = builder.get(&opened_values.chips, i);
194            let QuotientData { log_quotient_degree, quotient_size } =
195                builder.get(&proof.quotient_data, i);
196            let domain = pcs.natural_domain_for_log_degree(builder, Usize::Var(opening.log_degree));
197            builder.set_value(&mut trace_domains, i, domain.clone());
198
199            let log_quotient_size: Usize<_> =
200                builder.eval(opening.log_degree + log_quotient_degree);
201            let quotient_domain =
202                domain.create_disjoint_domain(builder, log_quotient_size, Some(pcs.config.clone()));
203            builder.set_value(&mut quotient_domains, i, quotient_domain.clone());
204
205            // Get trace_opening_points.
206            let mut trace_points = builder.dyn_array::<Ext<_, _>>(2);
207            let zeta_next = domain.next_point(builder, zeta);
208            builder.set_value(&mut trace_points, 0, zeta);
209            builder.set_value(&mut trace_points, 1, zeta_next);
210
211            // Get the main matrix.
212            let mut main_values = builder.dyn_array::<Array<C, _>>(2);
213            builder.set_value(&mut main_values, 0, opening.main.local);
214            builder.set_value(&mut main_values, 1, opening.main.next);
215            let main_mat = TwoAdicPcsMatsVariable::<C> {
216                domain: domain.clone(),
217                values: main_values,
218                points: trace_points.clone(),
219            };
220            builder.set_value(&mut main_mats, i, main_mat);
221
222            // Get the permutation matrix.
223            let mut perm_values = builder.dyn_array::<Array<C, _>>(2);
224            builder.set_value(&mut perm_values, 0, opening.permutation.local);
225            builder.set_value(&mut perm_values, 1, opening.permutation.next);
226            let perm_mat = TwoAdicPcsMatsVariable::<C> {
227                domain: domain.clone(),
228                values: perm_values,
229                points: trace_points,
230            };
231            builder.set_value(&mut perm_mats, i, perm_mat);
232
233            // Get the quotient matrices and values.
234            let qc_domains =
235                quotient_domain.split_domains(builder, log_quotient_degree, quotient_size);
236
237            builder.range(0, qc_domains.len()).for_each(|j, builder| {
238                let qc_dom = builder.get(&qc_domains, j);
239                let qc_vals_array = builder.get(&opening.quotient, j);
240                let mut qc_values = builder.dyn_array::<Array<C, _>>(1);
241                builder.set_value(&mut qc_values, 0, qc_vals_array);
242                let qc_mat = TwoAdicPcsMatsVariable::<C> {
243                    domain: qc_dom,
244                    values: qc_values,
245                    points: qc_points.clone(),
246                };
247                builder.set_value(&mut quotient_mats, qc_index, qc_mat);
248                builder.assign(qc_index, qc_index + C::N::one());
249            });
250        });
251
252        // Create the pcs rounds.
253        let mut rounds = builder.dyn_array::<TwoAdicPcsRoundVariable<_>>(4);
254        let prep_commit = vk.commitment.clone();
255        let prep_round = TwoAdicPcsRoundVariable { batch_commit: prep_commit, mats: prep_mats };
256        let main_round =
257            TwoAdicPcsRoundVariable { batch_commit: main_commit.clone(), mats: main_mats };
258        let perm_round =
259            TwoAdicPcsRoundVariable { batch_commit: permutation_commit.clone(), mats: perm_mats };
260        let quotient_round =
261            TwoAdicPcsRoundVariable { batch_commit: quotient_commit.clone(), mats: quotient_mats };
262        builder.set_value(&mut rounds, 0, prep_round);
263        builder.set_value(&mut rounds, 1, main_round);
264        builder.set_value(&mut rounds, 2, perm_round);
265        builder.set_value(&mut rounds, 3, quotient_round);
266        builder.cycle_tracker("stage-c-verify-shard-setup");
267
268        // Verify the pcs proof
269        builder.cycle_tracker("stage-d-verify-pcs");
270        pcs.verify(builder, rounds, opening_proof.clone(), challenger);
271        builder.cycle_tracker("stage-d-verify-pcs");
272
273        builder.cycle_tracker("stage-e-verify-constraints");
274
275        let num_shard_chips_enabled: Var<_> = builder.eval(C::N::zero());
276        for (i, chip) in machine.chips().iter().enumerate() {
277            tracing::debug!("verifying constraints for chip: {}", chip.name());
278            let index = builder.get(&proof.sorted_idxs, i);
279
280            if chip.preprocessed_width() > 0 {
281                builder.assert_var_ne(index, C::N::from_canonical_usize(EMPTY));
282            }
283
284            builder.if_ne(index, C::N::from_canonical_usize(EMPTY)).then(|builder| {
285                let values = builder.get(&opened_values.chips, index);
286                let trace_domain = builder.get(&trace_domains, index);
287                let quotient_domain: TwoAdicMultiplicativeCosetVariable<_> =
288                    builder.get(&quotient_domains, index);
289
290                // Check that the quotient data matches the chip's data.
291                let log_quotient_degree = chip.log_quotient_degree();
292
293                let quotient_size = 1 << log_quotient_degree;
294                let chip_quotient_data = builder.get(&proof.quotient_data, index);
295                builder
296                    .assert_usize_eq(chip_quotient_data.log_quotient_degree, log_quotient_degree);
297                builder.assert_usize_eq(chip_quotient_data.quotient_size, quotient_size);
298
299                // Get the domains from the chip itself.
300                let qc_domains = quotient_domain.split_domains_const(builder, log_quotient_degree);
301
302                // Verify the constraints.
303                stacker::maybe_grow(16 * 1024 * 1024, 16 * 1024 * 1024, || {
304                    Self::verify_constraints(
305                        builder,
306                        chip,
307                        &values,
308                        proof.public_values.clone(),
309                        trace_domain,
310                        qc_domains,
311                        zeta,
312                        alpha,
313                        &permutation_challenges,
314                    );
315                });
316
317                // Increment the number of shard chips that are enabled.
318                builder.assign(num_shard_chips_enabled, num_shard_chips_enabled + C::N::one());
319            });
320        }
321
322        // Assert that the number of chips in `opened_values` matches the number of shard chips
323        // enabled.
324        builder.assert_var_eq(num_shard_chips_enabled, num_shard_chips);
325
326        // If we're checking the cumulative sum, assert that the sum of the cumulative sums is zero.
327        if check_cumulative_sum {
328            let sum: Ext<_, _> = builder.eval(C::EF::zero().cons());
329            builder.range(0, proof.opened_values.chips.len()).for_each(|i, builder| {
330                let cumulative_sum = builder.get(&proof.opened_values.chips, i).cumulative_sum;
331                builder.assign(sum, sum + cumulative_sum);
332            });
333            builder.assert_ext_eq(sum, C::EF::zero().cons());
334        }
335
336        builder.cycle_tracker("stage-e-verify-constraints");
337    }
338}
339
#[cfg(test)]
pub(crate) mod tests {
    use std::{borrow::BorrowMut, time::Instant};

    use crate::{
        challenger::{CanObserveVariable, FeltChallenger},
        hints::Hintable,
        machine::commit_public_values,
        stark::{DuplexChallengerVariable, Ext, ShardProofHint},
        types::ShardCommitmentVariable,
    };
    use p3_challenger::{CanObserve, FieldChallenger};
    use p3_field::AbstractField;
    use rand::Rng;
    use sp1_core_executor::Program;
    use sp1_core_machine::{io::SP1Stdin, riscv::RiscvAir, utils::setup_logger};
    use sp1_recursion_compiler::{
        asm::AsmBuilder,
        config::InnerConfig,
        ir::{Array, Builder, Config, ExtConst, Felt, Usize},
    };
    use sp1_recursion_core::{
        air::{
            RecursionPublicValues, RECURSION_PUBLIC_VALUES_COL_MAP, RECURSIVE_PROOF_NUM_PV_ELTS,
        },
        runtime::{RecursionProgram, Runtime, DIGEST_SIZE},
        stark::{
            utils::{run_test_recursion, TestConfig},
            RecursionAir,
        },
    };
    use sp1_stark::{
        air::POSEIDON_NUM_WORDS, baby_bear_poseidon2::BabyBearPoseidon2, CpuProver, InnerChallenge,
        InnerVal, MachineProver, SP1CoreOpts, StarkGenericConfig,
    };

    type SC = BabyBearPoseidon2;
    type Challenge = <SC as StarkGenericConfig>::Challenge;
    type F = InnerVal;
    type EF = InnerChallenge;
    type C = InnerConfig;
    type A = RiscvAir<F>;

    /// Checks that the in-circuit challenger samples the same two permutation challenges as the
    /// native challenger. Both first observe the same verifying key, main commitments and
    /// public values.
    #[test]
    fn test_permutation_challenges() {
        // Generate a dummy proof.
        setup_logger();
        let elf = include_bytes!("../../../../tests/fibonacci/elf/riscv32im-succinct-zkvm-elf");

        let machine = A::machine(SC::default());
        let (_, vk) = machine.setup(&Program::from(elf).unwrap());
        let mut challenger_val = machine.config().challenger();
        let (proof, _, _) = sp1_core_machine::utils::prove::<_, CpuProver<_, _>>(
            Program::from(elf).unwrap(),
            &SP1Stdin::new(),
            SC::default(),
            SP1CoreOpts::default(),
        )
        .unwrap();
        let proofs = proof.shard_proofs;
        println!("Proof generated successfully");

        challenger_val.observe(vk.commit);

        proofs.iter().for_each(|proof| {
            challenger_val.observe(proof.commitment.main_commit);
            challenger_val.observe_slice(&proof.public_values[0..machine.num_pv_elts()]);
        });

        let permutation_challenges =
            (0..2).map(|_| challenger_val.sample_ext_element::<EF>()).collect::<Vec<_>>();

        // Observe all the commitments.
        let mut builder = Builder::<InnerConfig>::default();

        // Add a hash invocation, since the poseidon2 table expects that it's in the first row.
        let hash_input = builder.constant(vec![vec![F::one()]]);
        builder.poseidon2_hash_x(&hash_input);

        let mut challenger = DuplexChallengerVariable::new(&mut builder);

        let preprocessed_commit_val: [F; DIGEST_SIZE] = vk.commit.into();
        let preprocessed_commit: Array<C, _> = builder.constant(preprocessed_commit_val.to_vec());
        challenger.observe(&mut builder, preprocessed_commit);

        let mut witness_stream = Vec::new();
        for proof in proofs {
            let proof_hint = ShardProofHint::new(&machine, &proof);
            witness_stream.extend(proof_hint.write());
            let proof = ShardProofHint::<SC, A>::read(&mut builder);
            let ShardCommitmentVariable { main_commit, .. } = proof.commitment;
            challenger.observe(&mut builder, main_commit);
            let pv_slice = proof.public_values.slice(
                &mut builder,
                Usize::Const(0),
                Usize::Const(machine.num_pv_elts()),
            );
            challenger.observe_slice(&mut builder, pv_slice);
        }

        // Sample the permutation challenges.
        let permutation_challenges_var =
            (0..2).map(|_| challenger.sample_ext(&mut builder)).collect::<Vec<_>>();

        for (var, expected) in permutation_challenges_var.into_iter().zip(permutation_challenges) {
            builder.assert_ext_eq(var, expected.cons());
        }
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, Some(witness_stream.into()), TestConfig::All);
    }

    /// Builds a recursion program that fills a public-values struct with fixed constants and
    /// commits it.
    fn test_public_values_program() -> RecursionProgram<InnerVal> {
        let mut builder = Builder::<InnerConfig>::default();

        // Add a hash invocation, since the poseidon2 table expects that it's in the first row.
        let hash_input = builder.constant(vec![vec![F::one()]]);
        builder.poseidon2_hash_x(&hash_input);

        let mut public_values_stream: Vec<Felt<_>> =
            (0..RECURSIVE_PROOF_NUM_PV_ELTS).map(|_| builder.uninit()).collect();

        let public_values: &mut RecursionPublicValues<_> =
            public_values_stream.as_mut_slice().borrow_mut();

        public_values.sp1_vk_digest = [builder.constant(<C as Config>::F::zero()); DIGEST_SIZE];
        public_values.next_pc = builder.constant(<C as Config>::F::one());
        public_values.next_execution_shard = builder.constant(<C as Config>::F::two());
        public_values.end_reconstruct_deferred_digest =
            [builder.constant(<C as Config>::F::from_canonical_usize(3)); POSEIDON_NUM_WORDS];

        public_values.deferred_proofs_digest =
            [builder.constant(<C as Config>::F::from_canonical_usize(4)); POSEIDON_NUM_WORDS];

        public_values.cumulative_sum =
            [builder.constant(<C as Config>::F::from_canonical_usize(5)); 4];

        commit_public_values(&mut builder, public_values);
        builder.halt();

        builder.compile_program()
    }

    /// Proves the public-values program. Checks that the proof verifies and that it stops
    /// verifying once a public-values digest element is corrupted.
    #[test]
    fn test_public_values_failure() {
        let program = test_public_values_program();

        let config = SC::default();

        let mut runtime = Runtime::<InnerVal, Challenge, _>::new(&program, config.perm.clone());
        runtime.run().unwrap();

        let machine = RecursionAir::<_, 3>::machine(SC::default());
        let prover = CpuProver::new(machine);
        let (pk, vk) = prover.setup(&program);
        let record = runtime.record.clone();

        let mut challenger = prover.config().challenger();
        let mut proof =
            prover.prove(&pk, vec![record], &mut challenger, SP1CoreOpts::recursion()).unwrap();

        // `verify` advances the challenger it is given, so every verification needs a fresh one.
        let mut challenger = prover.config().challenger();
        let verification_result = prover.machine().verify(&vk, &proof, &mut challenger);
        assert!(verification_result.is_ok(), "Proof should verify successfully");

        // Corrupt the public values.
        proof.shard_proofs[0].public_values[RECURSION_PUBLIC_VALUES_COL_MAP.digest[0]] =
            InnerVal::zero();
        // Use a fresh challenger so the corruption is the only difference from the successful
        // run above. Reusing the already-advanced challenger would make verification fail
        // regardless, and this check would prove nothing.
        let mut challenger = prover.config().challenger();
        let verification_result = prover.machine().verify(&vk, &proof, &mut challenger);
        assert!(verification_result.is_err(), "Proof should not verify successfully");
    }

    /// Smoke test for basic felt/ext arithmetic and printing with the assembly builder.
    #[test]
    #[ignore]
    fn test_kitchen_sink() {
        setup_logger();

        let time = Instant::now();
        let mut builder = AsmBuilder::<F, EF>::default();

        let a: Felt<_> = builder.eval(F::from_canonical_u32(23));
        let b: Felt<_> = builder.eval(F::from_canonical_u32(17));
        let a_plus_b = builder.eval(a + b);
        let mut rng = rand::thread_rng();
        let a_ext_val = rng.gen::<EF>();
        let b_ext_val = rng.gen::<EF>();
        let a_ext: Ext<_, _> = builder.eval(a_ext_val.cons());
        let b_ext: Ext<_, _> = builder.eval(b_ext_val.cons());
        let a_plus_b_ext = builder.eval(a_ext + b_ext);
        builder.print_f(a_plus_b);
        builder.print_e(a_plus_b_ext);
        builder.halt();

        let program = builder.compile_program();
        let elapsed = time.elapsed();
        println!("Building took: {:?}", elapsed);

        run_test_recursion(program, None, TestConfig::All);
    }
}