Skip to main content

sp1_recursion_circuit/jagged/
verifier.rs

1use std::{iter::repeat_n, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use slop_algebra::AbstractField;
5use slop_jagged::{JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof};
6use slop_multilinear::{Mle, MleEval, Point};
7use slop_sumcheck::PartialSumcheckProof;
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10    circuit::CircuitV2Builder,
11    ir::{Builder, Ext, Felt, SymbolicExt, SymbolicFelt},
12};
13
14use crate::{
15    basefold::{
16        stacked::{RecursiveStackedPcsProof, RecursiveStackedPcsVerifier},
17        RecursiveBasefoldProof, RecursiveBasefoldVerifier,
18    },
19    challenger::FieldChallengerVariable,
20    sumcheck::{evaluate_mle_ext, verify_sumcheck},
21    CircuitConfig, SP1FieldConfigVariable,
22};
23
24use super::jagged_eval::{RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig};
25
/// Zero-sized, type-level marker parameterized over `C`, `SC` and `P`.
///
/// NOTE(review): this type is not referenced anywhere else in this file; confirm it is still
/// used elsewhere in the crate.
pub struct RecursivePcsImpl<C, SC, P> {
    _marker: PhantomData<(C, SC, P)>,
}
29
/// In-circuit representation of a jagged PCS proof, as consumed by
/// [`RecursiveJaggedPcsVerifier::verify_trusted_evaluations`].
///
/// `Proof` is the proof type of the underlying stacked PCS and `Digest` the commitment type.
pub struct JaggedPcsProofVariable<Proof, Digest> {
    /// Jagged little-polynomial verifier parameters. The verifier reads `col_prefix_sums` to size
    /// the column point and to check the total area of the committed tables.
    pub params: JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
    /// Sumcheck proof for the batched column claims; its final point is the point at which the
    /// stacked PCS opening is verified.
    pub sumcheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
    /// Proof consumed by the jagged-evaluation step.
    pub jagged_eval_proof: JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>,
    /// Stacked PCS opening proof, verified against `original_commitments`.
    pub pcs_proof: RecursiveStackedPcsProof<Proof, SP1Field, SP1ExtensionField>,
    /// Column count of every table, grouped by round. The verifier hashes these into the
    /// commitment check and uses the second-to-last entry of each round to compute how many
    /// zero padding columns were added.
    pub column_counts: Vec<Vec<usize>>,
    /// Row count of every table, grouped by round, as circuit variables.
    pub row_counts: Vec<Vec<Felt<SP1Field>>>,
    /// Commitments before the table shape is bound in; for each round the verifier checks that
    /// `compress(original, hash(shape))` equals the commitment supplied by the caller.
    pub original_commitments: Vec<Digest>,
    /// Claimed evaluation of the dense trace polynomial at the sumcheck point.
    pub expected_eval: Ext<SP1Field, SP1ExtensionField>,
}
40
/// In-circuit verifier for the jagged PCS, built on top of a stacked Basefold PCS verifier.
#[derive(Clone)]
pub struct RecursiveJaggedPcsVerifier<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
    /// Verifier for the final opening of the original (unpadded-shape) commitments.
    pub stacked_pcs_verifier: RecursiveStackedPcsVerifier<RecursiveBasefoldVerifier<C, SC>>,
    /// Maximum log2 row count of the committed tables.
    /// NOTE(review): not read by `verify_trusted_evaluations` in this file — confirm where it is
    /// consumed.
    pub max_log_row_count: usize,
    /// Performs the jagged-evaluation step of verification.
    pub jagged_evaluator: RecursiveJaggedEvalSumcheckConfig<SC>,
}
47
48impl<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> RecursiveJaggedPcsVerifier<SC, C> {
49    #[allow(clippy::too_many_arguments)]
50    pub fn verify_trusted_evaluations(
51        &self,
52        builder: &mut Builder<C>,
53        commitments: &[SC::DigestVariable],
54        point: Point<Ext<SP1Field, SP1ExtensionField>>,
55        evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
56        proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
57        insertion_points: &[usize],
58        challenger: &mut SC::FriChallengerVariable,
59    ) -> Vec<Felt<SP1Field>> {
60        let JaggedPcsProofVariable {
61            pcs_proof,
62            sumcheck_proof,
63            jagged_eval_proof,
64            params,
65            column_counts,
66            original_commitments,
67            expected_eval,
68            ..
69        } = proof;
70        let num_col_variables = (params.col_prefix_sums.len() - 1).next_power_of_two().ilog2();
71
72        let z_col =
73            (0..num_col_variables).map(|_| challenger.sample_ext(builder)).collect::<Point<_>>();
74
75        let z_row = point;
76
77        // Collect the claims for the different polynomials.
78        let mut column_claims = evaluation_claims.iter().flatten().copied().collect::<Vec<_>>();
79
80        let added_columns: Vec<usize> =
81            column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
82        // For each commit, Rizz needed a commitment to a vector of length a multiple of
83        // 1 << self.pcs.log_stacking_height, and this is achieved by adding a single column of
84        // zeroes as the last matrix of the commitment. We insert these "artificial" zeroes
85        // into the evaluation claims.
86        let zero_ext: Ext<SP1Field, SP1ExtensionField> =
87            builder.constant(SP1ExtensionField::zero());
88        for (insertion_point, num_added_columns) in
89            insertion_points.iter().rev().zip(added_columns.iter().rev())
90        {
91            for _ in 0..*num_added_columns {
92                column_claims.insert(*insertion_point, zero_ext);
93            }
94        }
95
96        for (round_column_counts, round_row_counts, modified_commitment, original_commitment) in izip!(
97            column_counts.iter(),
98            proof.row_counts.iter(),
99            commitments.iter(),
100            original_commitments.iter()
101        ) {
102            let mut felts_vec: Vec<Felt<_>> =
103                vec![builder.eval(SP1Field::from_canonical_usize(round_column_counts.len()))];
104            for &count in round_row_counts {
105                felts_vec.push(builder.eval(count));
106            }
107
108            for &count in round_column_counts {
109                felts_vec.push(builder.eval(SP1Field::from_canonical_usize(count)));
110            }
111            let hash = SC::hash(builder, &felts_vec);
112            let expected_commitment = SC::compress(builder, [*original_commitment, hash]);
113
114            SC::assert_digest_eq(builder, expected_commitment, *modified_commitment);
115        }
116
117        // Pad the column claims to the next power of two.
118        column_claims.resize(column_claims.len().next_power_of_two(), zero_ext);
119
120        let column_mle = Mle::from(column_claims);
121        let sumcheck_claim: Ext<SP1Field, SP1ExtensionField> =
122            evaluate_mle_ext(builder, column_mle, z_col.clone())[0];
123
124        builder.assert_ext_eq(sumcheck_claim, sumcheck_proof.claimed_sum);
125
126        builder.cycle_tracker_v2_enter("jagged - verify sumcheck");
127        verify_sumcheck::<C, SC>(builder, challenger, sumcheck_proof);
128        builder.cycle_tracker_v2_exit();
129
130        builder.cycle_tracker_v2_enter("jagged - jagged-eval");
131        let (jagged_eval, prefix_sum_felts) = self.jagged_evaluator.jagged_evaluation(
132            builder,
133            params,
134            z_row,
135            z_col,
136            sumcheck_proof.point_and_eval.0.clone(),
137            jagged_eval_proof,
138            challenger,
139        );
140        builder.cycle_tracker_v2_exit();
141
142        // Check the prefix_sum_felts against the row counts.
143        let repeated_flattened_row_counts: Vec<Felt<SP1Field>> = proof
144            .row_counts
145            .iter()
146            .flatten()
147            .zip_eq(column_counts.iter().flatten())
148            .flat_map(|(row, col)| repeat_n(*row, *col))
149            .collect();
150
151        let mut acc: Felt<_> = builder.constant(SP1Field::zero());
152
153        for (row_count, expected) in
154            repeated_flattened_row_counts.iter().zip_eq(prefix_sum_felts.iter())
155        {
156            builder.assert_felt_eq(acc, *expected);
157            acc = builder.eval(acc + *row_count)
158        }
159        let mut final_area = SymbolicFelt::zero();
160        let two: Felt<_> = builder.constant(SP1Field::two());
161        for bit in proof.params.col_prefix_sums.iter().last().unwrap().iter() {
162            final_area = *bit + two * final_area;
163        }
164        builder.assert_felt_eq(acc, final_area);
165
166        // Compute the expected evaluation of the dense trace polynomial.
167        builder.assert_ext_eq(jagged_eval * *expected_eval, sumcheck_proof.point_and_eval.1);
168
169        // Verify the evaluation proof.
170        let evaluation_point = sumcheck_proof.point_and_eval.0.clone();
171        self.stacked_pcs_verifier.verify_untrusted_evaluation(
172            builder,
173            original_commitments,
174            &evaluation_point,
175            pcs_proof,
176            SymbolicExt::from(*expected_eval),
177            challenger,
178        );
179        prefix_sum_felts
180    }
181}
182
/// Convenience wrapper around [`RecursiveJaggedPcsVerifier`] that derives the padding insertion
/// points from per-round column counts.
pub struct RecursiveMachineJaggedPcsVerifier<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
    /// The underlying jagged PCS verifier.
    pub jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
    /// Column count of every table, grouped by round; summed per round to locate where the
    /// zero padding claims belong.
    pub column_counts_by_round: Vec<Vec<usize>>,
}
187
188impl<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig>
189    RecursiveMachineJaggedPcsVerifier<'a, SC, C>
190{
191    pub fn new(
192        jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
193        column_counts_by_round: Vec<Vec<usize>>,
194    ) -> Self {
195        Self { jagged_pcs_verifier, column_counts_by_round }
196    }
197
198    pub fn verify_trusted_evaluations(
199        &self,
200        builder: &mut Builder<C>,
201        commitments: &[SC::DigestVariable],
202        point: Point<Ext<SP1Field, SP1ExtensionField>>,
203        evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
204        proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
205        challenger: &mut SC::FriChallengerVariable,
206    ) -> Vec<Felt<SP1Field>> {
207        let insertion_points = self
208            .column_counts_by_round
209            .iter()
210            .scan(0, |state, y| {
211                *state += y.iter().sum::<usize>();
212                Some(*state)
213            })
214            .collect::<Vec<_>>();
215
216        self.jagged_pcs_verifier.verify_trusted_evaluations(
217            builder,
218            commitments,
219            point,
220            evaluation_claims,
221            proof,
222            &insertion_points,
223            challenger,
224        )
225    }
226}
227
/// End-to-end test: produce a native jagged PCS proof, check it with the native verifier, then
/// compile and execute the recursive verifier circuit on the same artifacts, once honestly and
/// once with a tampered commitment.
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use rand::{thread_rng, Rng};
    use slop_algebra::AbstractField;
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_challenger::{CanObserve, IopCtx};
    use slop_commit::Rounds;
    use slop_jagged::{JaggedPcsProof, JaggedPcsVerifier, JaggedProver};
    use slop_multilinear::{Evaluations, Mle, MleEval, PaddedMle, Point};
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::{
        inner_perm, prover::SP1InnerPcsProver, SP1InnerPcs, SP1PcsProof, SP1PcsProofInner,
    };
    use sp1_primitives::{SP1DiffusionMatrix, SP1ExtensionField, SP1Field, SP1GlobalContext};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder};
    use sp1_recursion_executor::Executor;

    use crate::{
        basefold::{
            stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
            RecursiveBasefoldVerifier,
        },
        challenger::{CanObserveVariable, DuplexChallengerVariable},
        jagged::{
            jagged_eval::RecursiveJaggedEvalSumcheckConfig,
            verifier::{RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier},
        },
        witness::Witnessable,
    };

    // `SC` and `GC` both alias `SP1GlobalContext`; `SC` is used for the recursive (circuit)
    // side and `GC` for the native prover/verifier side.
    type SC = SP1GlobalContext;
    type JC = SP1InnerPcs;
    type GC = SP1GlobalContext;
    type F = SP1Field;
    type EF = SP1ExtensionField;
    type C = AsmConfig;
    type Prover = JaggedProver<SP1GlobalContext, SP1PcsProofInner, SP1InnerPcsProver>;

    /// Commits every round of `round_mles` with a jagged prover derived from `jagged_verifier`,
    /// evaluates every MLE at `eval_point`, and proves those evaluations.
    ///
    /// Returns the proof, the per-round commitments, and the per-round evaluation claims.
    #[allow(clippy::type_complexity)]
    fn generate_jagged_proof(
        jagged_verifier: &JaggedPcsVerifier<GC, JC>,
        round_mles: Rounds<Vec<PaddedMle<F>>>,
        eval_point: Point<EF>,
    ) -> (
        JaggedPcsProof<GC, SP1PcsProof<GC>>,
        Rounds<<GC as IopCtx>::Digest>,
        Rounds<Evaluations<EF>>,
    ) {
        let jagged_prover = Prover::from_verifier(jagged_verifier);

        let mut challenger = jagged_verifier.challenger();

        let mut prover_data = Rounds::new();
        let mut commitments = Rounds::new();
        for round in round_mles.iter() {
            let (commit, data) = jagged_prover.commit_multilinears(round.clone()).ok().unwrap();
            challenger.observe(commit);
            // Round-trip the prover data through bincode before using it.
            // NOTE(review): presumably this exercises the data's (de)serialization — confirm.
            let data_bytes = bincode::serialize(&data).unwrap();
            let data = bincode::deserialize(&data_bytes).unwrap();
            prover_data.push(data);
            commitments.push(commit);
        }

        // Claimed evaluations: every MLE of every round evaluated at `eval_point`.
        let mut evaluation_claims = Rounds::new();
        for round in round_mles.iter() {
            let mut evals = Evaluations::default();
            for mle in round.iter() {
                let eval = mle.eval_at(&eval_point);
                evals.push(eval);
            }
            evaluation_claims.push(evals);
        }

        let proof = jagged_prover
            .prove_trusted_evaluations(
                eval_point.clone(),
                evaluation_claims.clone(),
                prover_data,
                &mut challenger,
            )
            .ok()
            .unwrap();

        (proof, commitments, evaluation_claims)
    }

    #[test]
    fn test_jagged_verifier() {
        setup_logger();

        // Two rounds of tables: row counts (all powers of two) and matching column counts.
        let row_counts_rounds = vec![
            vec![
                1 << 13,
                1 << 8,
                1 << 11,
                1 << 7,
                1 << 16,
                1 << 14,
                1 << 20,
                1 << 7,
                1 << 9,
                1 << 11,
                1 << 8,
                1 << 7,
                1 << 14,
                1 << 10,
                1 << 14,
                1 << 8,
            ],
            vec![1 << 8],
        ];
        let column_counts_rounds = vec![
            vec![47, 41, 41, 58, 52, 109, 428, 50, 53, 93, 100, 83, 31, 68, 134, 80],
            vec![512],
        ];

        let num_rounds = row_counts_rounds.len();

        let log_stacking_height = 21;
        let max_log_row_count = 20;

        let row_counts = row_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();
        let column_counts = column_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();

        assert!(row_counts.len() == column_counts.len());

        let mut rng = thread_rng();

        // Random tables, padded with zeros up to `max_log_row_count` variables. Row counts are
        // powers of two in this test, so `ilog(2)` is exact.
        let round_mles = row_counts
            .iter()
            .zip(column_counts.iter())
            .map(|(row_counts, col_counts)| {
                row_counts
                    .iter()
                    .zip(col_counts.iter())
                    .map(|(num_rows, num_cols)| {
                        if *num_rows == 0 {
                            PaddedMle::zeros(*num_cols, max_log_row_count)
                        } else {
                            let mle = Mle::<F>::rand(&mut rng, *num_cols, num_rows.ilog(2));
                            PaddedMle::padded_with_zeros(Arc::new(mle), max_log_row_count)
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Rounds<_>>();

        let jagged_verifier = JaggedPcsVerifier::<GC, JC>::new_from_basefold_params(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count as usize,
            num_rounds,
        );

        let eval_point = (0..max_log_row_count).map(|_| rng.gen::<EF>()).collect::<Point<_>>();

        // Generate the jagged proof.
        let (proof, mut commitments, evaluation_claims) =
            generate_jagged_proof(&jagged_verifier, round_mles, eval_point.clone());

        let mut challenger = jagged_verifier.challenger();

        for commitment in commitments.iter() {
            // Mirror the prover's transcript: observe every commitment before verifying.
            challenger.observe(*commitment);
        }

        // Flatten each round's evaluations into a single `MleEval` per round.
        let evaluation_claims = evaluation_claims
            .iter()
            .map(|round| {
                round.iter().flat_map(|evals| evals.iter().cloned()).collect::<MleEval<_>>()
            })
            .collect::<Vec<_>>();

        // Sanity check: the native verifier accepts the proof.
        jagged_verifier
            .verify_trusted_evaluations(
                &commitments,
                eval_point.clone(),
                &evaluation_claims,
                &proof,
                &mut challenger,
            )
            .unwrap();

        // Define the verification circuit.
        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("jagged - read input");
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
        let commitments_var = commitments.read(&mut builder);
        let eval_point_var = eval_point.read(&mut builder);
        let evaluation_claims_var = evaluation_claims.read(&mut builder);
        let proof_var = proof.read(&mut builder);
        builder.cycle_tracker_v2_exit();
        builder.cycle_tracker_v2_enter("jagged - observe commitments");
        for commitment_var in commitments_var.iter() {
            challenger_variable.observe_slice(&mut builder, *commitment_var);
        }
        builder.cycle_tracker_v2_exit();
        let verifier = BasefoldVerifier::<SC>::new(FriConfig::default_fri_config(), num_rounds);
        let recursive_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let recursive_verifier =
            RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<SC, C> {
            stacked_pcs_verifier: recursive_verifier,
            max_log_row_count: max_log_row_count as usize,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<SP1GlobalContext>(PhantomData),
        };

        // The machine wrapper gets the per-round column counts of the tables as generated
        // above; this list is written out for exactly two rounds.
        let recursive_jagged_verifier = RecursiveMachineJaggedPcsVerifier::new(
            &recursive_jagged_verifier,
            vec![column_counts[0].clone(), column_counts[1].clone()],
        );

        builder.cycle_tracker_v2_enter("jagged-verifier");
        recursive_jagged_verifier.verify_trusted_evaluations(
            &mut builder,
            &commitments_var,
            eval_point_var,
            &evaluation_claims_var,
            &proof_var,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();

        // Compile the verification circuit.
        let program = compiler.compile_inner(block).validate().unwrap();

        // Run the verification circuit with the proof artifacts. The witness stream must be
        // written in the same order the circuit reads it above.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program.clone()), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().unwrap();

        // Run the verification circuit with the proof artifacts with an expected failure:
        // perturb the first field element of the first round's commitment.
        let mut witness_stream = Vec::new();
        commitments.rounds[0][0] += F::one();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().expect_err("invalid proof should not be verified");
    }
}