// sp1_recursion_circuit/basefold/stacked.rs
use super::RecursiveMultilinearPcsVerifier;
use crate::{challenger::FieldChallengerVariable, sumcheck::evaluate_mle_ext};
use slop_commit::Rounds;
use slop_multilinear::{Mle, MleEval, Point};
use sp1_primitives::{SP1ExtensionField, SP1Field};
use sp1_recursion_compiler::{
    circuit::CircuitV2Builder,
    ir::{Builder, Ext, SymbolicExt},
};
10
/// In-circuit verifier for a stacked multilinear PCS.
///
/// Evaluation points handed to this verifier are split in two:
/// - The trailing `log_stacking_height` coordinates are forwarded to the wrapped verifier.
/// - The leading coordinates are used to fold the per-column batch evaluations into a single
///   claimed value.
#[derive(Clone)]
pub struct RecursiveStackedPcsVerifier<P> {
    /// The wrapped recursive PCS verifier that checks the batch evaluations against the
    /// commitments at the trailing ("stack") part of the point.
    pub recursive_pcs_verifier: P,
    /// Number of trailing point coordinates that address a position inside one stacked column.
    pub log_stacking_height: u32,
}
16
17pub struct RecursiveStackedPcsProof<PcsProof, F, EF> {
18 pub batch_evaluations: Rounds<MleEval<Ext<F, EF>>>,
19 pub pcs_proof: PcsProof,
20}
21
22impl<P: RecursiveMultilinearPcsVerifier> RecursiveStackedPcsVerifier<P> {
23 pub const fn new(recursive_pcs_verifier: P, log_stacking_height: u32) -> Self {
24 Self { recursive_pcs_verifier, log_stacking_height }
25 }
26
27 pub fn verify_untrusted_evaluation(
28 &self,
29 builder: &mut Builder<P::Circuit>,
30 commitments: &[P::Commitment],
31 point: &Point<Ext<SP1Field, SP1ExtensionField>>,
32 proof: &RecursiveStackedPcsProof<P::Proof, SP1Field, SP1ExtensionField>,
33 evaluation_claim: SymbolicExt<SP1Field, SP1ExtensionField>,
34 challenger: &mut P::Challenger,
35 ) {
36 let claim_ext: Ext<_, _> = builder.eval(evaluation_claim);
37 challenger.observe_ext_element(builder, claim_ext);
38 let (batch_point, stack_point) =
39 point.split_at(point.dimension() - self.log_stacking_height as usize);
40 let batch_evaluations =
41 proof.batch_evaluations.iter().flatten().cloned().collect::<Mle<_>>();
42
43 builder.cycle_tracker_v2_enter("rizz - evaluate_mle_ext");
44 let expected_evaluation = evaluate_mle_ext(builder, batch_evaluations, batch_point)[0];
45 builder.assert_ext_eq(claim_ext, expected_evaluation);
46 builder.cycle_tracker_v2_exit();
47
48 builder.cycle_tracker_v2_enter("rizz - verify_untrusted_evaluations");
49 self.recursive_pcs_verifier.verify_untrusted_evaluations(
50 builder,
51 commitments,
52 stack_point,
53 &proof.batch_evaluations,
54 &proof.pcs_proof,
55 challenger,
56 );
57 builder.cycle_tracker_v2_exit();
58 }
59}
60
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use slop_challenger::IopCtx;
    use slop_commit::Message;
    use sp1_core_machine::utils::setup_logger;
    use sp1_recursion_compiler::{circuit::AsmConfig, config::InnerConfig};
    use std::{collections::VecDeque, marker::PhantomData, sync::Arc};

    use slop_algebra::extension::BinomialExtensionField;
    use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};

    use crate::{
        basefold::{tcs::RecursiveMerkleTreeTcs, RecursiveBasefoldVerifier},
        challenger::DuplexChallengerVariable,
        witness::Witnessable,
    };

    use super::*;

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_basefold_prover::BasefoldProver;
    use slop_challenger::CanObserve;

    use slop_commit::Rounds;

    use crate::challenger::CanObserveVariable;
    use slop_multilinear::{Mle, MultilinearPcsProver};
    use slop_stacked::StackedPcsProver;
    use sp1_hypercube::{inner_perm, prover::SP1MerkleTreeProver};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler};
    use sp1_recursion_executor::Executor;

    use sp1_primitives::SP1Field;
    type F = SP1Field;

    /// End-to-end check of `RecursiveStackedPcsVerifier::verify_untrusted_evaluation`.
    ///
    /// Steps:
    /// 1. Commits random MLEs round by round with the native `StackedPcsProver`.
    /// 2. Produces an untrusted-evaluation proof at a random point.
    /// 3. Builds the recursive verifier circuit for that proof, compiles it and runs it in the
    ///    recursion `Executor`.
    ///
    /// `round_widths_and_log_heights` holds one `Vec` per commitment round. Each `(w, log_h)`
    /// pair becomes one `Mle::rand(&mut rng, w, log_h)`. The total `sum(w << log_h)` over all
    /// rounds must be a power of two; this is asserted below.
    fn test_round_widths_and_log_heights(
        round_widths_and_log_heights: &[Vec<(usize, u32)>],
        log_stacking_height: u32,
        batch_size: usize,
    ) {
        type C = InnerConfig;
        type SC = SP1GlobalContext;
        type Prover = BasefoldProver<SP1GlobalContext, SP1MerkleTreeProver>;
        type EF = BinomialExtensionField<SP1Field, 4>;
        // Total number of base-field entries across all rounds; the evaluation point has
        // log2 of this many variables.
        let total_data_length = round_widths_and_log_heights
            .iter()
            .map(|dims| dims.iter().map(|&(w, log_h)| w << log_h).sum::<usize>())
            .sum::<usize>();
        let total_number_of_variables = total_data_length.next_power_of_two().ilog2();
        assert_eq!(1 << total_number_of_variables, total_data_length);

        let mut rng = thread_rng();
        let round_mles = round_widths_and_log_heights
            .iter()
            .map(|dims| {
                dims.iter()
                    .map(|&(w, log_h)| Mle::<SP1Field>::rand(&mut rng, w, log_h))
                    .collect::<Message<_>>()
            })
            .collect::<Rounds<_>>();

        // The second argument is the number of rounds (presumably the number of commitments
        // the verifier expects).
        let pcs_verifier = BasefoldVerifier::<SC>::new(
            FriConfig::default_fri_config(),
            round_widths_and_log_heights.len(),
        );
        let pcs_prover = Prover::new(&pcs_verifier);

        let prover = StackedPcsProver::new(pcs_prover, log_stacking_height, batch_size);

        let mut challenger = SC::default_challenger();
        let mut commitments = vec![];
        let mut prover_data = Rounds::new();
        let mut batch_evaluations = Rounds::new();
        let point = Point::<EF>::rand(&mut rng, total_number_of_variables);

        // Same split the recursive verifier performs: trailing `log_stacking_height`
        // coordinates form the stack point.
        let (batch_point, stack_point) =
            point.split_at(point.dimension() - log_stacking_height as usize);
        // Commit each round, observing commitments in order, and record the per-round batch
        // evaluations at the stack point.
        for mles in round_mles.iter() {
            let (commitment, data, _) = prover.commit_multilinear(mles.clone()).unwrap();
            challenger.observe(commitment);
            commitments.push(commitment);
            let evaluations = prover.round_batch_evaluations(&stack_point, &data);
            prover_data.push(data);
            batch_evaluations.push(evaluations);
        }

        // The honest claim is the flattened batch evaluations evaluated at the batch point,
        // which is exactly what the recursive verifier recomputes and compares against.
        let batch_evaluations_mle =
            batch_evaluations.iter().flatten().flatten().cloned().collect::<Mle<_>>();
        let eval_claim = batch_evaluations_mle.eval_at(&batch_point)[0];

        let proof = prover
            .prove_untrusted_evaluation(point.clone(), eval_claim, prover_data, &mut challenger)
            .unwrap();

        // ---- Build the recursive verifier circuit ----
        let mut builder = AsmBuilder::default();
        let mut witness_stream = Vec::new();
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);

        // Each value is written to the witness stream and then read back as a circuit
        // variable. Write/read pairs are kept in the same order, presumably because the
        // executor consumes the stream sequentially.
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        let commitments = commitments.read(&mut builder);

        // Replay the native transcript: observe the commitments in the same order the native
        // challenger did before `prove_untrusted_evaluation`.
        for commitment in commitments.iter() {
            challenger_variable.observe(&mut builder, *commitment);
        }

        Witnessable::<AsmConfig>::write(&point, &mut witness_stream);
        let point = point.read(&mut builder);

        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let proof = proof.read(&mut builder);

        Witnessable::<AsmConfig>::write(&eval_claim, &mut witness_stream);
        let eval_claim = eval_claim.read(&mut builder);

        let verifier = BasefoldVerifier::<SC>::new(
            FriConfig::default_fri_config(),
            round_widths_and_log_heights.len(),
        );
        let recursive_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let recursive_verifier =
            RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);

        // `.into()` turns the `Ext` claim into the `SymbolicExt` the verifier expects.
        recursive_verifier.verify_untrusted_evaluation(
            &mut builder,
            &commitments,
            &point,
            &proof,
            eval_claim.into(),
            &mut challenger_variable,
        );

        // ---- Compile and execute the circuit ----
        // `run().unwrap()` presumably fails if any in-circuit assertion is violated.
        let mut buf = VecDeque::<u8>::new();
        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(program.clone(), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.debug_stdout = Box::new(&mut buf);
        executor.run().unwrap();
    }

    /// Single round whose total data is 2^20 + 2^15 + 496 * 2^11 = 2^21 entries.
    #[test]
    fn test_stacked_pcs_proof() {
        setup_logger();
        let round_widths_and_log_heights: Vec<(usize, u32)> =
            vec![(1 << 10, 10), (1 << 4, 11), (496, 11)];
        test_round_widths_and_log_heights(&[round_widths_and_log_heights], 10, 10);
    }

    /// Single round whose total data is 2^29 entries.
    ///
    /// Run with `batch_size` 1 and 5. Ignored by default; invoke it explicitly.
    #[test]
    #[ignore = "should be invoked specifically"]
    fn test_stacked_pcs_proof_core_shard() {
        setup_logger();
        let round_widths_and_log_heights = [vec![
            (30, 21),
            (44, 21),
            (45, 21),
            (18, 20),
            (400, 18),
            (25, 20),
            (100, 20),
            (40, 19),
            (22, 19),
        ]];
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 1);
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 5);
    }

    /// Single round whose total data is 2^29 entries.
    ///
    /// Run with `batch_size` 1 and 5. Ignored by default; invoke it explicitly.
    #[test]
    #[ignore = "should be invoked specifically"]
    fn test_stacked_pcs_proof_precompile_shard() {
        setup_logger();
        let round_widths_and_log_heights = [vec![(4000, 16), (400, 19), (20, 20), (21, 21)]];
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 1);
        test_round_widths_and_log_heights(&round_widths_and_log_heights, 21, 5);
    }
}