Skip to main content

sp1_recursion_circuit/basefold/
stacked.rs

1use super::RecursiveMultilinearPcsVerifier;
2use crate::{challenger::FieldChallengerVariable, sumcheck::evaluate_mle_ext};
3use slop_commit::Rounds;
4use slop_multilinear::{Mle, MleEval, Point};
5use sp1_primitives::{SP1ExtensionField, SP1Field};
6use sp1_recursion_compiler::{
7    circuit::CircuitV2Builder,
8    ir::{Builder, Ext, SymbolicExt},
9};
10
/// In-circuit verifier for a stacked multilinear PCS.
///
/// Wraps an inner recursive PCS verifier `P` (the `impl` block requires
/// `P: RecursiveMultilinearPcsVerifier`) together with the base-2 log of the stacking height.
#[derive(Clone)]
pub struct RecursiveStackedPcsVerifier<P> {
    /// Inner verifier that receives the stacking coordinates, the per-round batch evaluations
    /// and the inner PCS proof.
    pub recursive_pcs_verifier: P,
    /// Number of trailing coordinates of an evaluation point that are treated as stacking
    /// coordinates; the remaining leading coordinates index the batch.
    pub log_stacking_height: u32,
}
16
17pub struct RecursiveStackedPcsProof<PcsProof, F, EF> {
18    pub batch_evaluations: Rounds<MleEval<Ext<F, EF>>>,
19    pub pcs_proof: PcsProof,
20}
21
22impl<P: RecursiveMultilinearPcsVerifier> RecursiveStackedPcsVerifier<P> {
23    pub const fn new(recursive_pcs_verifier: P, log_stacking_height: u32) -> Self {
24        Self { recursive_pcs_verifier, log_stacking_height }
25    }
26
27    pub fn verify_untrusted_evaluation(
28        &self,
29        builder: &mut Builder<P::Circuit>,
30        commitments: &[P::Commitment],
31        point: &Point<Ext<SP1Field, SP1ExtensionField>>,
32        proof: &RecursiveStackedPcsProof<P::Proof, SP1Field, SP1ExtensionField>,
33        evaluation_claim: SymbolicExt<SP1Field, SP1ExtensionField>,
34        challenger: &mut P::Challenger,
35    ) {
36        let claim_ext: Ext<_, _> = builder.eval(evaluation_claim);
37        challenger.observe_ext_element(builder, claim_ext);
38        let (batch_point, stack_point) =
39            point.split_at(point.dimension() - self.log_stacking_height as usize);
40        let batch_evaluations =
41            proof.batch_evaluations.iter().flatten().cloned().collect::<Mle<_>>();
42
43        builder.cycle_tracker_v2_enter("rizz - evaluate_mle_ext");
44        let expected_evaluation = evaluate_mle_ext(builder, batch_evaluations, batch_point)[0];
45        builder.assert_ext_eq(claim_ext, expected_evaluation);
46        builder.cycle_tracker_v2_exit();
47
48        builder.cycle_tracker_v2_enter("rizz - verify_untrusted_evaluations");
49        self.recursive_pcs_verifier.verify_untrusted_evaluations(
50            builder,
51            commitments,
52            stack_point,
53            &proof.batch_evaluations,
54            &proof.pcs_proof,
55            challenger,
56        );
57        builder.cycle_tracker_v2_exit();
58    }
59}
60
// End-to-end tests: commit with the native stacked PCS prover, prove an untrusted evaluation,
// then build, compile and execute the recursive verifier circuit on that proof.
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use slop_challenger::IopCtx;
    use slop_commit::Message;
    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_compiler::{circuit::AsmConfig, config::InnerConfig};
    use std::{collections::VecDeque, marker::PhantomData, sync::Arc};

    use slop_algebra::extension::BinomialExtensionField;
    use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};

    use crate::{
        basefold::{tcs::RecursiveMerkleTreeTcs, RecursiveBasefoldVerifier},
        challenger::DuplexChallengerVariable,
        witness::Witnessable,
    };

    use super::*;

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_basefold_prover::BasefoldProver;
    use slop_challenger::CanObserve;

    use slop_commit::Rounds;

    use crate::challenger::CanObserveVariable;
    use slop_multilinear::{Mle, MultilinearPcsProver};
    use slop_stacked::StackedPcsProver;
    use sp1_hypercube::{inner_perm, prover::SP1MerkleTreeProver};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler};
    use sp1_recursion_executor::Executor;

    use sp1_primitives::SP1Field;
    type F = SP1Field;

    /// Runs one native-prove / recursive-verify round trip for the given MLE shapes.
    ///
    /// * `round_widths_and_log_heights` - one entry per commitment round; each entry lists the
    ///   `(width, log_height)` of the random MLEs committed in that round.
    /// * `log_stacking_height` - stacking height (log2) given to both the native prover and the
    ///   recursive verifier.
    /// * `batch_size` - forwarded to `StackedPcsProver::new`.
    ///
    /// # Panics
    ///
    /// Panics if the total number of entries across all MLEs is not a power of two, or if any
    /// proving, compilation or execution step returns an error.
    fn test_round_widths_and_log_heights(
        round_widths_and_log_heights: &[Vec<(usize, u32)>],
        log_stacking_height: u32,
        batch_size: usize,
    ) {
        type C = InnerConfig;
        type SC = SP1GlobalContext;
        type Prover = BasefoldProver<SP1GlobalContext, SP1MerkleTreeProver>;
        type EF = BinomialExtensionField<SP1Field, 4>;
        // Total number of entries (width * 2^log_height) over every MLE in every round.
        let total_data_length = round_widths_and_log_heights
            .iter()
            .map(|dims| dims.iter().map(|&(w, log_h)| w << log_h).sum::<usize>())
            .sum::<usize>();
        let total_number_of_variables = total_data_length.next_power_of_two().ilog2();
        // The evaluation point has one coordinate per variable, so the data must fill the
        // hypercube exactly.
        assert_eq!(1 << total_number_of_variables, total_data_length);

        let mut rng = thread_rng();
        let round_mles = round_widths_and_log_heights
            .iter()
            .map(|dims| {
                dims.iter()
                    .map(|&(w, log_h)| Mle::<SP1Field>::rand(&mut rng, w, log_h))
                    .collect::<Message<_>>()
            })
            .collect::<Rounds<_>>();

        let pcs_verifier = BasefoldVerifier::<SC>::new(
            FriConfig::default_fri_config(),
            round_widths_and_log_heights.len(),
        );
        let pcs_prover = Prover::new(&pcs_verifier);

        let prover = StackedPcsProver::new(pcs_prover, log_stacking_height, batch_size);

        let mut challenger = SC::default_challenger();
        let mut commitments = vec![];
        let mut prover_data = Rounds::new();
        let mut batch_evaluations = Rounds::new();
        let point = Point::<EF>::rand(&mut rng, total_number_of_variables);

        // Same split as the recursive verifier: the trailing `log_stacking_height` coordinates
        // form the stacking point.
        let (batch_point, stack_point) =
            point.split_at(point.dimension() - log_stacking_height as usize);
        for mles in round_mles.iter() {
            let (commitment, data, _) = prover.commit_multilinear(mles.clone()).unwrap();
            challenger.observe(commitment);
            commitments.push(commitment);
            let evaluations = prover.round_batch_evaluations(&stack_point, &data);
            prover_data.push(data);
            batch_evaluations.push(evaluations);
        }

        // Interpolate the batch evaluations as a multilinear polynomial.
        let batch_evaluations_mle =
            batch_evaluations.iter().flatten().flatten().cloned().collect::<Mle<_>>();
        // Evaluate the interpolated batch evaluations at `batch_point`; this is the evaluation
        // claim that is proven below and checked by the recursive verifier.
        let eval_claim = batch_evaluations_mle.eval_at(&batch_point)[0];

        let proof = prover
            .prove_untrusted_evaluation(point.clone(), eval_claim, prover_data, &mut challenger)
            .unwrap();

        let mut builder = AsmBuilder::default();
        let mut witness_stream = Vec::new();
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);

        // Each value is written to `witness_stream` right before `read` allocates the matching
        // circuit variables; presumably the executor consumes the stream in this same order, so
        // the write/read pairs must stay aligned.
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        let commitments = commitments.read(&mut builder);

        // Mirror the native transcript, which observed every commitment before proving.
        for commitment in commitments.iter() {
            challenger_variable.observe(&mut builder, *commitment);
        }

        Witnessable::<AsmConfig>::write(&point, &mut witness_stream);
        let point = point.read(&mut builder);

        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let proof = proof.read(&mut builder);

        Witnessable::<AsmConfig>::write(&eval_claim, &mut witness_stream);
        let eval_claim = eval_claim.read(&mut builder);

        let verifier = BasefoldVerifier::<SC>::new(
            FriConfig::default_fri_config(),
            round_widths_and_log_heights.len(),
        );
        let recursive_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let recursive_verifier =
            RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);

        recursive_verifier.verify_untrusted_evaluation(
            &mut builder,
            &commitments,
            &point,
            &proof,
            eval_claim.into(),
            &mut challenger_variable,
        );

        // Compile the circuit and run it on the witness stream; `unwrap` fails the test if
        // validation or execution errors.
        let mut buf = VecDeque::<u8>::new();
        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(program.clone(), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.debug_stdout = Box::new(&mut buf);
        executor.run().unwrap();
    }

    // A single round of three MLEs totalling 2^20 + 2^15 + 496 * 2^11 = 2^21 entries.
    #[test]
    fn test_stacked_pcs_proof() {
        setup_logger();
        let round_widths_and_log_heights: Vec<(usize, u32)> =
            vec![(1 << 10, 10), (1 << 4, 11), (496, 11)];
        test_round_widths_and_log_heights(&[round_widths_and_log_heights], 10, 10);
    }

    // Larger single-round shapes (presumably modelled on a core shard) totalling 2^29 entries;
    // run with batch sizes 1 and 5.
    #[test]
    #[ignore = "should be invoked specifically"]
    fn test_stacked_pcs_proof_core_shard() {
        setup_logger();
        let round_widths_and_log_heights = [vec![
            (30, 21),
            (44, 21),
            (45, 21),
            (18, 20),
            (400, 18),
            (25, 20),
            (100, 20),
            (40, 19),
            (22, 19),
        ]];
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 1);
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 5);
    }

    // Larger single-round shapes (presumably modelled on a precompile shard) totalling 2^29
    // entries; run with batch sizes 1 and 5.
    #[test]
    #[ignore = "should be invoked specifically"]
    fn test_stacked_pcs_proof_precompile_shard() {
        setup_logger();
        let round_widths_and_log_heights = [vec![(4000, 16), (400, 19), (20, 20), (21, 21)]];
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 1);
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 5);
    }
}