Skip to main content

sp1_recursion_circuit/jagged/verifier.rs

1use std::marker::PhantomData;
2
3use itertools::izip;
4use slop_algebra::AbstractField;
5use slop_jagged::{JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof};
6use slop_multilinear::{Mle, MleEval, Point};
7use slop_sumcheck::PartialSumcheckProof;
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10    circuit::CircuitV2Builder,
11    ir::{Builder, Ext, Felt, SymbolicExt},
12};
13
14use crate::{
15    basefold::{
16        stacked::{RecursiveStackedPcsProof, RecursiveStackedPcsVerifier},
17        RecursiveBasefoldProof, RecursiveBasefoldVerifier,
18    },
19    challenger::FieldChallengerVariable,
20    sumcheck::{evaluate_mle_ext, verify_sumcheck},
21    CircuitConfig, SP1FieldConfigVariable,
22};
23
24use super::jagged_eval::{RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig};
25
/// Data-free marker type parameterized by `C`, `SC`, and `P`.
///
/// Its only field is a `PhantomData`, so it is zero-sized and exists purely to carry the
/// three type parameters.
pub struct RecursivePcsImpl<C, SC, P> {
    /// Records the type parameters without storing any values of them.
    _marker: PhantomData<(C, SC, P)>,
}
29
30pub struct JaggedPcsProofVariable<Proof, Digest> {
31    pub params: JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
32    pub sumcheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
33    pub jagged_eval_proof: JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>,
34    pub pcs_proof: RecursiveStackedPcsProof<Proof, SP1Field, SP1ExtensionField>,
35    pub column_counts: Vec<Vec<usize>>,
36    pub row_counts: Vec<Vec<Felt<SP1Field>>>,
37    pub original_commitments: Vec<Digest>,
38    pub expected_eval: Ext<SP1Field, SP1ExtensionField>,
39}
40
41#[derive(Clone)]
42pub struct RecursiveJaggedPcsVerifier<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
43    pub stacked_pcs_verifier: RecursiveStackedPcsVerifier<RecursiveBasefoldVerifier<C, SC>>,
44    pub max_log_row_count: usize,
45    pub jagged_evaluator: RecursiveJaggedEvalSumcheckConfig<SC>,
46}
47
48impl<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> RecursiveJaggedPcsVerifier<SC, C> {
49    #[allow(clippy::too_many_arguments)]
50    pub fn verify_trusted_evaluations(
51        &self,
52        builder: &mut Builder<C>,
53        commitments: &[SC::DigestVariable],
54        point: Point<Ext<SP1Field, SP1ExtensionField>>,
55        evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
56        proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
57        insertion_points: &[usize],
58        challenger: &mut SC::FriChallengerVariable,
59    ) -> Vec<Felt<SP1Field>> {
60        let JaggedPcsProofVariable {
61            pcs_proof,
62            sumcheck_proof,
63            jagged_eval_proof,
64            params,
65            column_counts,
66            original_commitments,
67            expected_eval,
68            ..
69        } = proof;
70        let num_col_variables = (params.col_prefix_sums.len() - 1).next_power_of_two().ilog2();
71
72        let z_col =
73            (0..num_col_variables).map(|_| challenger.sample_ext(builder)).collect::<Point<_>>();
74
75        let z_row = point;
76
77        // Collect the claims for the different polynomials.
78        let mut column_claims = evaluation_claims.iter().flatten().copied().collect::<Vec<_>>();
79
80        let added_columns: Vec<usize> =
81            column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
82        // For each commit, Rizz needed a commitment to a vector of length a multiple of
83        // 1 << self.pcs.log_stacking_height, and this is achieved by adding a single column of
84        // zeroes as the last matrix of the commitment. We insert these "artificial" zeroes
85        // into the evaluation claims.
86        let zero_ext: Ext<SP1Field, SP1ExtensionField> =
87            builder.constant(SP1ExtensionField::zero());
88        for (insertion_point, num_added_columns) in
89            insertion_points.iter().rev().zip(added_columns.iter().rev())
90        {
91            for _ in 0..*num_added_columns {
92                column_claims.insert(*insertion_point, zero_ext);
93            }
94        }
95
96        for (round_column_counts, round_row_counts, modified_commitment, original_commitment) in izip!(
97            column_counts.iter(),
98            proof.row_counts.iter(),
99            commitments.iter(),
100            original_commitments.iter()
101        ) {
102            let mut felts_vec: Vec<Felt<_>> =
103                vec![builder.eval(SP1Field::from_canonical_usize(round_column_counts.len()))];
104            for &count in round_row_counts {
105                felts_vec.push(builder.eval(count));
106            }
107
108            for &count in round_column_counts {
109                felts_vec.push(builder.eval(SP1Field::from_canonical_usize(count)));
110            }
111            let hash = SC::hash(builder, &felts_vec);
112            let expected_commitment = SC::compress(builder, [*original_commitment, hash]);
113
114            SC::assert_digest_eq(builder, expected_commitment, *modified_commitment);
115        }
116
117        // Pad the column claims to the next power of two.
118        column_claims.resize(column_claims.len().next_power_of_two(), zero_ext);
119
120        let column_mle = Mle::from(column_claims);
121        let sumcheck_claim: Ext<SP1Field, SP1ExtensionField> =
122            evaluate_mle_ext(builder, column_mle, z_col.clone())[0];
123
124        builder.assert_ext_eq(sumcheck_claim, sumcheck_proof.claimed_sum);
125
126        builder.cycle_tracker_v2_enter("jagged - verify sumcheck");
127        verify_sumcheck::<C, SC>(builder, challenger, sumcheck_proof);
128        builder.cycle_tracker_v2_exit();
129
130        builder.cycle_tracker_v2_enter("jagged - jagged-eval");
131        let (jagged_eval, prefix_sum_felts) = self.jagged_evaluator.jagged_evaluation(
132            builder,
133            params,
134            z_row,
135            z_col,
136            sumcheck_proof.point_and_eval.0.clone(),
137            jagged_eval_proof,
138            challenger,
139        );
140        builder.cycle_tracker_v2_exit();
141
142        // Compute the expected evaluation of the dense trace polynomial.
143        builder.assert_ext_eq(jagged_eval * *expected_eval, sumcheck_proof.point_and_eval.1);
144
145        // Verify the evaluation proof.
146        let evaluation_point = sumcheck_proof.point_and_eval.0.clone();
147        self.stacked_pcs_verifier.verify_untrusted_evaluation(
148            builder,
149            original_commitments,
150            &evaluation_point,
151            pcs_proof,
152            SymbolicExt::from(*expected_eval),
153            challenger,
154        );
155        prefix_sum_felts
156    }
157}
158
159pub struct RecursiveMachineJaggedPcsVerifier<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
160    pub jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
161    pub column_counts_by_round: Vec<Vec<usize>>,
162}
163
164impl<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig>
165    RecursiveMachineJaggedPcsVerifier<'a, SC, C>
166{
167    pub fn new(
168        jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
169        column_counts_by_round: Vec<Vec<usize>>,
170    ) -> Self {
171        Self { jagged_pcs_verifier, column_counts_by_round }
172    }
173
174    pub fn verify_trusted_evaluations(
175        &self,
176        builder: &mut Builder<C>,
177        commitments: &[SC::DigestVariable],
178        point: Point<Ext<SP1Field, SP1ExtensionField>>,
179        evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
180        proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
181        challenger: &mut SC::FriChallengerVariable,
182    ) -> Vec<Felt<SP1Field>> {
183        let insertion_points = self
184            .column_counts_by_round
185            .iter()
186            .scan(0, |state, y| {
187                *state += y.iter().sum::<usize>();
188                Some(*state)
189            })
190            .collect::<Vec<_>>();
191
192        self.jagged_pcs_verifier.verify_trusted_evaluations(
193            builder,
194            commitments,
195            point,
196            evaluation_claims,
197            proof,
198            &insertion_points,
199            challenger,
200        )
201    }
202}
203
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use rand::{thread_rng, Rng};
    use slop_algebra::AbstractField;
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_challenger::{CanObserve, IopCtx};
    use slop_commit::Rounds;
    use slop_jagged::{JaggedPcsProof, JaggedPcsVerifier, JaggedProver};
    use slop_multilinear::{Evaluations, Mle, MleEval, PaddedMle, Point};
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::{
        inner_perm, prover::SP1InnerPcsProver, SP1InnerPcs, SP1PcsProof, SP1PcsProofInner,
    };
    use sp1_primitives::{SP1DiffusionMatrix, SP1ExtensionField, SP1Field, SP1GlobalContext};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder};
    use sp1_recursion_executor::Executor;

    use crate::{
        basefold::{
            stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
            RecursiveBasefoldVerifier,
        },
        challenger::{CanObserveVariable, DuplexChallengerVariable},
        jagged::{
            jagged_eval::RecursiveJaggedEvalSumcheckConfig,
            verifier::{RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier},
        },
        witness::Witnessable,
    };

    type SC = SP1GlobalContext;
    type JC = SP1InnerPcs;
    type GC = SP1GlobalContext;
    type F = SP1Field;
    type EF = SP1ExtensionField;
    type C = AsmConfig;
    type Prover = JaggedProver<SP1GlobalContext, SP1PcsProofInner, SP1InnerPcsProver>;

    /// Commits to `round_mles` round by round and evaluates every MLE at `eval_point`. It then
    /// produces a jagged proof of those evaluations.
    ///
    /// Returns `(proof, per-round commitments, per-round evaluation claims)`.
    #[allow(clippy::type_complexity)] // the tuple mirrors the three prover artifacts
    fn generate_jagged_proof(
        jagged_verifier: &JaggedPcsVerifier<GC, JC>,
        round_mles: Rounds<Vec<PaddedMle<F>>>,
        eval_point: Point<EF>,
    ) -> (
        JaggedPcsProof<GC, SP1PcsProof<GC>>,
        Rounds<<GC as IopCtx>::Digest>,
        Rounds<Evaluations<EF>>,
    ) {
        let jagged_prover = Prover::from_verifier(jagged_verifier);

        let mut challenger = jagged_verifier.challenger();

        let mut prover_data = Rounds::new();
        let mut commitments = Rounds::new();
        for round in round_mles.iter() {
            let (commit, data) = jagged_prover.commit_multilinears(round.clone()).ok().unwrap();
            challenger.observe(commit);
            // Round-trip the prover data through bincode before proving; presumably this
            // checks that the prover data survives serialization.
            let data_bytes = bincode::serialize(&data).unwrap();
            let data = bincode::deserialize(&data_bytes).unwrap();
            prover_data.push(data);
            commitments.push(commit);
        }

        let mut evaluation_claims = Rounds::new();
        for round in round_mles.iter() {
            let mut evals = Evaluations::default();
            for mle in round.iter() {
                let eval = mle.eval_at(&eval_point);
                evals.push(eval);
            }
            evaluation_claims.push(evals);
        }

        let proof = jagged_prover
            .prove_trusted_evaluations(
                eval_point.clone(),
                evaluation_claims.clone(),
                prover_data,
                &mut challenger,
            )
            .ok()
            .unwrap();

        (proof, commitments, evaluation_claims)
    }

    /// End-to-end check of the recursive jagged verifier.
    ///
    /// The test proves random MLEs natively and checks the proof with the native verifier.
    /// It then runs the compiled verification circuit, which must accept the real
    /// commitments and reject a corrupted one.
    #[test]
    fn test_jagged_verifier() {
        setup_logger();

        let row_counts_rounds = vec![
            vec![
                1 << 13,
                1 << 8,
                1 << 11,
                1 << 7,
                1 << 16,
                1 << 14,
                1 << 20,
                1 << 7,
                1 << 9,
                1 << 11,
                1 << 8,
                1 << 7,
                1 << 14,
                1 << 10,
                1 << 14,
                1 << 8,
            ],
            vec![1 << 8],
        ];
        let column_counts_rounds = vec![
            vec![47, 41, 41, 58, 52, 109, 428, 50, 53, 93, 100, 83, 31, 68, 134, 80],
            vec![512],
        ];

        let num_rounds = row_counts_rounds.len();

        let log_stacking_height = 21;
        let max_log_row_count = 20;

        let row_counts = row_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();
        let column_counts = column_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();

        assert!(row_counts.len() == column_counts.len());

        let mut rng = thread_rng();

        // One random MLE per (row count, column count) pair, zero-padded to the max height.
        let round_mles = row_counts
            .iter()
            .zip(column_counts.iter())
            .map(|(row_counts, col_counts)| {
                row_counts
                    .iter()
                    .zip(col_counts.iter())
                    .map(|(num_rows, num_cols)| {
                        if *num_rows == 0 {
                            PaddedMle::zeros(*num_cols, max_log_row_count)
                        } else {
                            let mle = Mle::<F>::rand(&mut rng, *num_cols, num_rows.ilog2());
                            PaddedMle::padded_with_zeros(Arc::new(mle), max_log_row_count)
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Rounds<_>>();

        let jagged_verifier = JaggedPcsVerifier::<GC, JC>::new_from_basefold_params(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count as usize,
            num_rounds,
        );

        let eval_point = (0..max_log_row_count).map(|_| rng.gen::<EF>()).collect::<Point<_>>();

        // Generate the jagged proof.
        let (proof, mut commitments, evaluation_claims) =
            generate_jagged_proof(&jagged_verifier, round_mles, eval_point.clone());

        let mut challenger = jagged_verifier.challenger();

        // Replay the prover's transcript: observe the commitments in the same order.
        for commitment in commitments.iter() {
            challenger.observe(*commitment);
        }

        let evaluation_claims = evaluation_claims
            .iter()
            .map(|round| {
                round.iter().flat_map(|evals| evals.iter().cloned()).collect::<MleEval<_>>()
            })
            .collect::<Vec<_>>();

        // Sanity check: the native verifier accepts the proof.
        jagged_verifier
            .verify_trusted_evaluations(
                &commitments,
                eval_point.clone(),
                &evaluation_claims,
                &proof,
                &mut challenger,
            )
            .unwrap();

        // Define the verification circuit. The read order below must match the order in
        // which the witness stream is written further down.
        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("jagged - read input");
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
        let commitments_var = commitments.read(&mut builder);
        let eval_point_var = eval_point.read(&mut builder);
        let evaluation_claims_var = evaluation_claims.read(&mut builder);
        let proof_var = proof.read(&mut builder);
        builder.cycle_tracker_v2_exit();
        builder.cycle_tracker_v2_enter("jagged - observe commitments");
        for commitment_var in commitments_var.iter() {
            challenger_variable.observe_slice(&mut builder, *commitment_var);
        }
        builder.cycle_tracker_v2_exit();
        let verifier = BasefoldVerifier::<SC>::new(FriConfig::default_fri_config(), num_rounds);
        let recursive_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let recursive_verifier =
            RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<SC, C> {
            stacked_pcs_verifier: recursive_verifier,
            max_log_row_count: max_log_row_count as usize,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<SP1GlobalContext>(PhantomData),
        };

        // Pass the column counts of every round rather than a hard-coded subset.
        let recursive_jagged_verifier = RecursiveMachineJaggedPcsVerifier::new(
            &recursive_jagged_verifier,
            column_counts.iter().cloned().collect(),
        );

        builder.cycle_tracker_v2_enter("jagged-verifier");
        recursive_jagged_verifier.verify_trusted_evaluations(
            &mut builder,
            &commitments_var,
            eval_point_var,
            &evaluation_claims_var,
            &proof_var,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();

        // Compile the verification circuit.
        let program = compiler.compile_inner(block).validate().unwrap();

        // Run the verification circuit with the proof artifacts.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program.clone()), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().unwrap();

        // Corrupt the first element of the first commitment; the circuit must now fail.
        let mut witness_stream = Vec::new();
        commitments.rounds[0][0] += F::one();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().expect_err("invalid proof should not be verified");
    }
}