sp1_recursion_circuit_v2/
stark.rs

1use hashbrown::HashMap;
2use itertools::{izip, Itertools};
3use p3_commit::Mmcs;
4use p3_matrix::dense::RowMajorMatrix;
5
6use p3_air::Air;
7use p3_baby_bear::BabyBear;
8use p3_commit::{Pcs, TwoAdicMultiplicativeCoset};
9use p3_field::TwoAdicField;
10use sp1_stark::{ShardCommitment, ShardOpenedValues, Val};
11
12use p3_commit::PolynomialSpace;
13
14use sp1_recursion_compiler::{
15    circuit::CircuitV2Builder,
16    ir::{Builder, Config, Ext},
17    prelude::Felt,
18};
19use sp1_stark::{air::MachineAir, StarkGenericConfig, StarkMachine, StarkVerifyingKey};
20
21use crate::{
22    challenger::CanObserveVariable, CircuitConfig, TwoAdicPcsMatsVariable, TwoAdicPcsProofVariable,
23};
24
25use crate::{
26    challenger::FieldChallengerVariable, constraints::RecursiveVerifierConstraintFolder,
27    domain::PolynomialSpaceVariable, fri::verify_two_adic_pcs, BabyBearFriConfigVariable,
28    TwoAdicPcsRoundVariable, VerifyingKeyVariable,
29};
30
31/// Reference: [sp1_core::stark::ShardProof]
32#[derive(Clone)]
33pub struct ShardProofVariable<C: CircuitConfig<F = SC::Val>, SC: BabyBearFriConfigVariable<C>> {
34    pub commitment: ShardCommitment<SC::Digest>,
35    pub opened_values: ShardOpenedValues<Ext<C::F, C::EF>>,
36    pub opening_proof: TwoAdicPcsProofVariable<C, SC>,
37    pub chip_ordering: HashMap<String, usize>,
38    pub public_values: Vec<Felt<C::F>>,
39}
40
/// Sentinel constant with the value `0x1111_1111`.
///
/// NOTE(review): nothing in this file uses it, so its meaning cannot be confirmed here.
/// Presumably it marks an "unset" slot for its users. TODO confirm.
pub const EMPTY: usize = 0x1111_1111;
42
43#[derive(Debug, Clone, Copy)]
44pub struct StarkVerifier<C: Config, SC: StarkGenericConfig, A> {
45    _phantom: std::marker::PhantomData<(C, SC, A)>,
46}
47
48pub struct VerifyingKeyHint<'a, SC: StarkGenericConfig, A> {
49    pub machine: &'a StarkMachine<SC, A>,
50    pub vk: &'a StarkVerifyingKey<SC>,
51}
52
53impl<'a, SC: StarkGenericConfig, A: MachineAir<SC::Val>> VerifyingKeyHint<'a, SC, A> {
54    pub const fn new(machine: &'a StarkMachine<SC, A>, vk: &'a StarkVerifyingKey<SC>) -> Self {
55        Self { machine, vk }
56    }
57}
58
59impl<C, SC, A> StarkVerifier<C, SC, A>
60where
61    C::F: TwoAdicField,
62    C: CircuitConfig<F = SC::Val>,
63    SC: BabyBearFriConfigVariable<C>,
64    <SC::ValMmcs as Mmcs<BabyBear>>::ProverData<RowMajorMatrix<BabyBear>>: Clone,
65    A: MachineAir<Val<SC>>,
66{
67    pub fn natural_domain_for_degree(
68        config: &SC,
69        degree: usize,
70    ) -> TwoAdicMultiplicativeCoset<C::F> {
71        <SC::Pcs as Pcs<SC::Challenge, SC::FriChallenger>>::natural_domain_for_degree(
72            config.pcs(),
73            degree,
74        )
75    }
76
77    pub fn verify_shard(
78        builder: &mut Builder<C>,
79        vk: &VerifyingKeyVariable<C, SC>,
80        machine: &StarkMachine<SC, A>,
81        challenger: &mut SC::FriChallengerVariable,
82        proof: &ShardProofVariable<C, SC>,
83    ) where
84        A: for<'a> Air<RecursiveVerifierConstraintFolder<'a, C>>,
85    {
86        let chips = machine.shard_chips_ordered(&proof.chip_ordering).collect::<Vec<_>>();
87
88        let ShardProofVariable {
89            commitment,
90            opened_values,
91            opening_proof,
92            chip_ordering,
93            public_values,
94        } = proof;
95
96        let log_degrees = opened_values.chips.iter().map(|val| val.log_degree).collect::<Vec<_>>();
97
98        let log_quotient_degrees =
99            chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();
100
101        let trace_domains = log_degrees
102            .iter()
103            .map(|log_degree| Self::natural_domain_for_degree(machine.config(), 1 << log_degree))
104            .collect::<Vec<_>>();
105
106        let ShardCommitment { main_commit, permutation_commit, quotient_commit } = *commitment;
107
108        let permutation_challenges =
109            (0..2).map(|_| challenger.sample_ext(builder)).collect::<Vec<_>>();
110
111        challenger.observe(builder, permutation_commit);
112
113        let alpha = challenger.sample_ext(builder);
114
115        challenger.observe(builder, quotient_commit);
116
117        let zeta = challenger.sample_ext(builder);
118
119        let preprocessed_domains_points_and_opens = vk
120            .chip_information
121            .iter()
122            .map(|(name, domain, _)| {
123                let i = chip_ordering[name];
124                let values = opened_values.chips[i].preprocessed.clone();
125                TwoAdicPcsMatsVariable::<C> {
126                    domain: *domain,
127                    points: vec![zeta, domain.next_point_variable(builder, zeta)],
128                    values: vec![values.local, values.next],
129                }
130            })
131            .collect::<Vec<_>>();
132
133        let main_domains_points_and_opens = trace_domains
134            .iter()
135            .zip_eq(opened_values.chips.iter())
136            .map(|(domain, values)| TwoAdicPcsMatsVariable::<C> {
137                domain: *domain,
138                points: vec![zeta, domain.next_point_variable(builder, zeta)],
139                values: vec![values.main.local.clone(), values.main.next.clone()],
140            })
141            .collect::<Vec<_>>();
142
143        let perm_domains_points_and_opens = trace_domains
144            .iter()
145            .zip_eq(opened_values.chips.iter())
146            .map(|(domain, values)| TwoAdicPcsMatsVariable::<C> {
147                domain: *domain,
148                points: vec![zeta, domain.next_point_variable(builder, zeta)],
149                values: vec![values.permutation.local.clone(), values.permutation.next.clone()],
150            })
151            .collect::<Vec<_>>();
152
153        let quotient_chunk_domains = trace_domains
154            .iter()
155            .zip_eq(log_degrees)
156            .zip_eq(log_quotient_degrees)
157            .map(|((domain, log_degree), log_quotient_degree)| {
158                let quotient_degree = 1 << log_quotient_degree;
159                let quotient_domain =
160                    domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree));
161                quotient_domain.split_domains(quotient_degree)
162            })
163            .collect::<Vec<_>>();
164
165        let quotient_domains_points_and_opens = proof
166            .opened_values
167            .chips
168            .iter()
169            .zip_eq(quotient_chunk_domains.iter())
170            .flat_map(|(values, qc_domains)| {
171                values.quotient.iter().zip_eq(qc_domains).map(move |(values, q_domain)| {
172                    TwoAdicPcsMatsVariable::<C> {
173                        domain: *q_domain,
174                        points: vec![zeta],
175                        values: vec![values.clone()],
176                    }
177                })
178            })
179            .collect::<Vec<_>>();
180
181        // Create the pcs rounds.
182        let prep_commit = vk.commitment;
183        let prep_round = TwoAdicPcsRoundVariable {
184            batch_commit: prep_commit,
185            domains_points_and_opens: preprocessed_domains_points_and_opens,
186        };
187        let main_round = TwoAdicPcsRoundVariable {
188            batch_commit: main_commit,
189            domains_points_and_opens: main_domains_points_and_opens,
190        };
191        let perm_round = TwoAdicPcsRoundVariable {
192            batch_commit: permutation_commit,
193            domains_points_and_opens: perm_domains_points_and_opens,
194        };
195        let quotient_round = TwoAdicPcsRoundVariable {
196            batch_commit: quotient_commit,
197            domains_points_and_opens: quotient_domains_points_and_opens,
198        };
199        let rounds = vec![prep_round, main_round, perm_round, quotient_round];
200
201        // Verify the pcs proof
202        builder.cycle_tracker_v2_enter("stage-d-verify-pcs".to_string());
203        let config = machine.config().fri_config();
204        verify_two_adic_pcs::<C, SC>(builder, config, opening_proof, challenger, rounds);
205        builder.cycle_tracker_v2_exit();
206
207        // Verify the constrtaint evaluations.
208        builder.cycle_tracker_v2_enter("stage-e-verify-constraints".to_string());
209        for (chip, trace_domain, qc_domains, values) in
210            izip!(chips.iter(), trace_domains, quotient_chunk_domains, opened_values.chips.iter(),)
211        {
212            // Verify the shape of the opening arguments matches the expected values.
213            Self::verify_opening_shape(chip, values).unwrap();
214            // Verify the constraint evaluation.
215            Self::verify_constraints(
216                builder,
217                chip,
218                values,
219                trace_domain,
220                qc_domains,
221                zeta,
222                alpha,
223                &permutation_challenges,
224                public_values,
225            );
226        }
227        builder.cycle_tracker_v2_exit();
228    }
229}
230
231impl<C: CircuitConfig<F = SC::Val>, SC: BabyBearFriConfigVariable<C>> ShardProofVariable<C, SC> {
232    pub fn contains_cpu(&self) -> bool {
233        self.chip_ordering.contains_key("CPU")
234    }
235
236    pub fn contains_memory_init(&self) -> bool {
237        self.chip_ordering.contains_key("MemoryInit")
238    }
239
240    pub fn contains_memory_finalize(&self) -> bool {
241        self.chip_ordering.contains_key("MemoryFinalize")
242    }
243}
244
// Several imports below are used only by the `#[test]` function, which is not compiled when only
// the `export-tests` feature is enabled. Some (e.g. `VecDeque`, `OuterConfig`) are not used in
// this module at all.
#[allow(unused_imports)]
#[cfg(any(test, feature = "export-tests"))]
pub mod tests {
    use std::collections::VecDeque;

    use crate::{
        challenger::{CanCopyChallenger, CanObserveVariable, DuplexChallengerVariable},
        utils::tests::run_test_recursion_with_prover,
        BabyBearFriConfig,
    };

    use sp1_core_executor::{programs::tests::FIBONACCI_ELF, Program};
    use sp1_core_machine::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{prove, setup_logger},
    };
    use sp1_recursion_compiler::{
        config::{InnerConfig, OuterConfig},
        ir::{Builder, DslIr, TracedVec},
    };

    use sp1_recursion_core_v2::{
        air::Block, machine::RecursionAir, stark::config::BabyBearPoseidon2Outer,
    };
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, CpuProver, InnerVal, MachineProver, SP1CoreOpts,
        ShardProof,
    };

    use super::*;
    use crate::witness::*;

    type F = InnerVal;
    type A = RiscvAir<F>;

    /// Builds a recursion program that verifies shards of a freshly generated core proof.
    ///
    /// The function proves `elf` with `CoreP` and checks the proof natively. It then writes the
    /// verifying key and every shard proof into the witness stream and observes the verifying
    /// key plus each shard's main commitment and public values into a circuit challenger.
    /// Finally it emits circuit verification for the first `num_shards_in_batch` shards, or for
    /// all shards when `None`.
    ///
    /// Returns the builder's recorded operations together with the witness stream.
    ///
    /// # Panics
    ///
    /// Panics if `elf` cannot be parsed, or if proving or native verification fails.
    pub fn build_verify_shard_with_provers<
        C: CircuitConfig<F = InnerVal, Bit = Felt<InnerVal>>,
        SC: BabyBearFriConfigVariable<C> + Default + Sync + Send,
        CoreP: MachineProver<SC, A>,
        RecP: MachineProver<SC, RecursionAir<F, 3, 0>>,
    >(
        config: SC,
        elf: &[u8],
        opts: SP1CoreOpts,
        num_shards_in_batch: Option<usize>,
    ) -> (TracedVec<DslIr<C>>, Vec<Block<BabyBear>>)
    where
        SC::Challenger: Send,
        <<SC as BabyBearFriConfig>::ValMmcs as Mmcs<BabyBear>>::ProverData<
            RowMajorMatrix<BabyBear>,
        >: Send + Sync,
        <<SC as BabyBearFriConfig>::ValMmcs as Mmcs<BabyBear>>::Commitment: Send + Sync,
        <<SC as BabyBearFriConfig>::ValMmcs as Mmcs<BabyBear>>::Proof: Send,
        StarkVerifyingKey<SC>: Witnessable<C, WitnessVariable = VerifyingKeyVariable<C, SC>>,
        ShardProof<SC>: Witnessable<C, WitnessVariable = ShardProofVariable<C, SC>>,
    {
        setup_logger();

        // Generate a real core proof for `elf` and sanity-check it natively before building the
        // recursive verifier. The ELF is parsed once and reused for setup and proving.
        let machine = RiscvAir::<C::F>::machine(SC::default());
        let program = Program::from(elf).expect("failed to parse the ELF");
        let (_, vk) = machine.setup(&program);
        let (proof, _, _) = prove::<_, CoreP>(program, &SP1Stdin::new(), SC::default(), opts)
            .expect("failed to generate the core proof");
        let mut challenger = machine.config().challenger();
        machine
            .verify(&vk, &proof, &mut challenger)
            .expect("native verification of the core proof failed");
        println!("Proof generated successfully");

        let mut builder = Builder::<C>::default();
        let mut witness_stream = Vec::<WitnessBlock<C>>::new();

        // Add a hash invocation, since the poseidon2 table expects that it's in the first row.
        let mut challenger = config.challenger_variable(&mut builder);

        // Pass the verifying key through the witness stream and observe it.
        vk.write(&mut witness_stream);
        let vk: VerifyingKeyVariable<_, _> = vk.read(&mut builder);
        vk.observe_into(&mut builder, &mut challenger);

        // Write every shard proof into the witness stream and read it back as a circuit variable.
        let proofs = proof
            .shard_proofs
            .into_iter()
            .map(|proof| {
                proof.write(&mut witness_stream);
                proof.read(&mut builder)
            })
            .collect::<Vec<_>>();

        // Observe each shard's main commitment and public values.
        for proof in proofs.iter() {
            let ShardCommitment { main_commit, .. } = proof.commitment;
            challenger.observe(&mut builder, main_commit);
            let pv_slice = &proof.public_values[..machine.num_pv_elts()];
            challenger.observe_slice(&mut builder, pv_slice.iter().cloned());
        }

        // Verify the first `num_shards` shard proofs, each against its own copy of the
        // challenger state built above.
        let num_shards = num_shards_in_batch.unwrap_or(proofs.len());
        for proof in proofs.into_iter().take(num_shards) {
            let mut challenger = challenger.copy(&mut builder);
            StarkVerifier::verify_shard(&mut builder, &vk, &machine, &mut challenger, &proof);
        }
        (builder.operations, witness_stream)
    }

    #[test]
    fn test_verify_shard_inner() {
        let (operations, stream) =
            build_verify_shard_with_provers::<
                InnerConfig,
                BabyBearPoseidon2,
                CpuProver<_, _>,
                CpuProver<_, _>,
            >(BabyBearPoseidon2::new(), FIBONACCI_ELF, SP1CoreOpts::default(), Some(2));
        run_test_recursion_with_prover::<CpuProver<_, _>>(operations, stream);
    }
}