1use std::{iter::repeat_n, marker::PhantomData};
2
3use itertools::{izip, Itertools};
4use slop_algebra::AbstractField;
5use slop_jagged::{JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof};
6use slop_multilinear::{Mle, MleEval, Point};
7use slop_sumcheck::PartialSumcheckProof;
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10 circuit::CircuitV2Builder,
11 ir::{Builder, Ext, Felt, SymbolicExt, SymbolicFelt},
12};
13
14use crate::{
15 basefold::{
16 stacked::{RecursiveStackedPcsProof, RecursiveStackedPcsVerifier},
17 RecursiveBasefoldProof, RecursiveBasefoldVerifier,
18 },
19 challenger::FieldChallengerVariable,
20 sumcheck::{evaluate_mle_ext, verify_sumcheck},
21 CircuitConfig, SP1FieldConfigVariable,
22};
23
24use super::jagged_eval::{RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig};
25
/// Zero-sized type parameterized over `C`, `SC` and `P`.
///
/// It stores no data, only a `PhantomData` over the three parameters. NOTE(review): its
/// intended role (presumably a type-level selector for a PCS implementation) is not visible in
/// this file.
pub struct RecursivePcsImpl<C, SC, P> {
    _marker: PhantomData<(C, SC, P)>,
}
29
30pub struct JaggedPcsProofVariable<Proof, Digest> {
31 pub params: JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
32 pub sumcheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
33 pub jagged_eval_proof: JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>,
34 pub pcs_proof: RecursiveStackedPcsProof<Proof, SP1Field, SP1ExtensionField>,
35 pub column_counts: Vec<Vec<usize>>,
36 pub row_counts: Vec<Vec<Felt<SP1Field>>>,
37 pub original_commitments: Vec<Digest>,
38 pub expected_eval: Ext<SP1Field, SP1ExtensionField>,
39}
40
41#[derive(Clone)]
42pub struct RecursiveJaggedPcsVerifier<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
43 pub stacked_pcs_verifier: RecursiveStackedPcsVerifier<RecursiveBasefoldVerifier<C, SC>>,
44 pub max_log_row_count: usize,
45 pub jagged_evaluator: RecursiveJaggedEvalSumcheckConfig<SC>,
46}
47
48impl<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> RecursiveJaggedPcsVerifier<SC, C> {
49 #[allow(clippy::too_many_arguments)]
50 pub fn verify_trusted_evaluations(
51 &self,
52 builder: &mut Builder<C>,
53 commitments: &[SC::DigestVariable],
54 point: Point<Ext<SP1Field, SP1ExtensionField>>,
55 evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
56 proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
57 insertion_points: &[usize],
58 challenger: &mut SC::FriChallengerVariable,
59 ) -> Vec<Felt<SP1Field>> {
60 let JaggedPcsProofVariable {
61 pcs_proof,
62 sumcheck_proof,
63 jagged_eval_proof,
64 params,
65 column_counts,
66 original_commitments,
67 expected_eval,
68 ..
69 } = proof;
70 let num_col_variables = (params.col_prefix_sums.len() - 1).next_power_of_two().ilog2();
71
72 let z_col =
73 (0..num_col_variables).map(|_| challenger.sample_ext(builder)).collect::<Point<_>>();
74
75 let z_row = point;
76
77 let mut column_claims = evaluation_claims.iter().flatten().copied().collect::<Vec<_>>();
79
80 let added_columns: Vec<usize> =
81 column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
82 let zero_ext: Ext<SP1Field, SP1ExtensionField> =
87 builder.constant(SP1ExtensionField::zero());
88 for (insertion_point, num_added_columns) in
89 insertion_points.iter().rev().zip(added_columns.iter().rev())
90 {
91 for _ in 0..*num_added_columns {
92 column_claims.insert(*insertion_point, zero_ext);
93 }
94 }
95
96 for (round_column_counts, round_row_counts, modified_commitment, original_commitment) in izip!(
97 column_counts.iter(),
98 proof.row_counts.iter(),
99 commitments.iter(),
100 original_commitments.iter()
101 ) {
102 let mut felts_vec: Vec<Felt<_>> =
103 vec![builder.eval(SP1Field::from_canonical_usize(round_column_counts.len()))];
104 for &count in round_row_counts {
105 felts_vec.push(builder.eval(count));
106 }
107
108 for &count in round_column_counts {
109 felts_vec.push(builder.eval(SP1Field::from_canonical_usize(count)));
110 }
111 let hash = SC::hash(builder, &felts_vec);
112 let expected_commitment = SC::compress(builder, [*original_commitment, hash]);
113
114 SC::assert_digest_eq(builder, expected_commitment, *modified_commitment);
115 }
116
117 column_claims.resize(column_claims.len().next_power_of_two(), zero_ext);
119
120 let column_mle = Mle::from(column_claims);
121 let sumcheck_claim: Ext<SP1Field, SP1ExtensionField> =
122 evaluate_mle_ext(builder, column_mle, z_col.clone())[0];
123
124 builder.assert_ext_eq(sumcheck_claim, sumcheck_proof.claimed_sum);
125
126 builder.cycle_tracker_v2_enter("jagged - verify sumcheck");
127 verify_sumcheck::<C, SC>(builder, challenger, sumcheck_proof);
128 builder.cycle_tracker_v2_exit();
129
130 builder.cycle_tracker_v2_enter("jagged - jagged-eval");
131 let (jagged_eval, prefix_sum_felts) = self.jagged_evaluator.jagged_evaluation(
132 builder,
133 params,
134 z_row,
135 z_col,
136 sumcheck_proof.point_and_eval.0.clone(),
137 jagged_eval_proof,
138 challenger,
139 );
140 builder.cycle_tracker_v2_exit();
141
142 let repeated_flattened_row_counts: Vec<Felt<SP1Field>> = proof
144 .row_counts
145 .iter()
146 .flatten()
147 .zip_eq(column_counts.iter().flatten())
148 .flat_map(|(row, col)| repeat_n(*row, *col))
149 .collect();
150
151 let mut acc: Felt<_> = builder.constant(SP1Field::zero());
152
153 for (row_count, expected) in
154 repeated_flattened_row_counts.iter().zip_eq(prefix_sum_felts.iter())
155 {
156 builder.assert_felt_eq(acc, *expected);
157 acc = builder.eval(acc + *row_count)
158 }
159 let mut final_area = SymbolicFelt::zero();
160 let two: Felt<_> = builder.constant(SP1Field::two());
161 for bit in proof.params.col_prefix_sums.iter().last().unwrap().iter() {
162 final_area = *bit + two * final_area;
163 }
164 builder.assert_felt_eq(acc, final_area);
165
166 builder.assert_ext_eq(jagged_eval * *expected_eval, sumcheck_proof.point_and_eval.1);
168
169 let evaluation_point = sumcheck_proof.point_and_eval.0.clone();
171 self.stacked_pcs_verifier.verify_untrusted_evaluation(
172 builder,
173 original_commitments,
174 &evaluation_point,
175 pcs_proof,
176 SymbolicExt::from(*expected_eval),
177 challenger,
178 );
179 prefix_sum_felts
180 }
181}
182
183pub struct RecursiveMachineJaggedPcsVerifier<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
184 pub jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
185 pub column_counts_by_round: Vec<Vec<usize>>,
186}
187
188impl<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig>
189 RecursiveMachineJaggedPcsVerifier<'a, SC, C>
190{
191 pub fn new(
192 jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
193 column_counts_by_round: Vec<Vec<usize>>,
194 ) -> Self {
195 Self { jagged_pcs_verifier, column_counts_by_round }
196 }
197
198 pub fn verify_trusted_evaluations(
199 &self,
200 builder: &mut Builder<C>,
201 commitments: &[SC::DigestVariable],
202 point: Point<Ext<SP1Field, SP1ExtensionField>>,
203 evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
204 proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
205 challenger: &mut SC::FriChallengerVariable,
206 ) -> Vec<Felt<SP1Field>> {
207 let insertion_points = self
208 .column_counts_by_round
209 .iter()
210 .scan(0, |state, y| {
211 *state += y.iter().sum::<usize>();
212 Some(*state)
213 })
214 .collect::<Vec<_>>();
215
216 self.jagged_pcs_verifier.verify_trusted_evaluations(
217 builder,
218 commitments,
219 point,
220 evaluation_claims,
221 proof,
222 &insertion_points,
223 challenger,
224 )
225 }
226}
227
/// End-to-end test: a native jagged PCS proof must verify both natively and inside the
/// recursion circuit, and must be rejected by the circuit once a commitment is tampered with.
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use rand::{thread_rng, Rng};
    use slop_algebra::AbstractField;
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_challenger::{CanObserve, IopCtx};
    use slop_commit::Rounds;
    use slop_jagged::{JaggedPcsProof, JaggedPcsVerifier, JaggedProver};
    use slop_multilinear::{Evaluations, Mle, MleEval, PaddedMle, Point};
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::{
        inner_perm, prover::SP1InnerPcsProver, SP1InnerPcs, SP1PcsProof, SP1PcsProofInner,
    };
    use sp1_primitives::{SP1DiffusionMatrix, SP1ExtensionField, SP1Field, SP1GlobalContext};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder};
    use sp1_recursion_executor::Executor;

    use crate::{
        basefold::{
            stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
            RecursiveBasefoldVerifier,
        },
        challenger::{CanObserveVariable, DuplexChallengerVariable},
        jagged::{
            jagged_eval::RecursiveJaggedEvalSumcheckConfig,
            verifier::{RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier},
        },
        witness::Witnessable,
    };

    type SC = SP1GlobalContext;
    type JC = SP1InnerPcs;
    type GC = SP1GlobalContext;
    type F = SP1Field;
    type EF = SP1ExtensionField;
    type C = AsmConfig;
    type Prover = JaggedProver<SP1GlobalContext, SP1PcsProofInner, SP1InnerPcsProver>;

    /// Commits to `round_mles` round by round and proves their evaluations at `eval_point`.
    ///
    /// Returns the native proof, the per-round commitments and the per-round evaluation
    /// claims. Each commitment is observed on a fresh verifier challenger before proving, in
    /// the same order the verifier side of the test observes them.
    #[allow(clippy::type_complexity)]
    fn generate_jagged_proof(
        jagged_verifier: &JaggedPcsVerifier<GC, JC>,
        round_mles: Rounds<Vec<PaddedMle<F>>>,
        eval_point: Point<EF>,
    ) -> (
        JaggedPcsProof<GC, SP1PcsProof<GC>>,
        Rounds<<GC as IopCtx>::Digest>,
        Rounds<Evaluations<EF>>,
    ) {
        let jagged_prover = Prover::from_verifier(jagged_verifier);

        let mut challenger = jagged_verifier.challenger();

        let mut prover_data = Rounds::new();
        let mut commitments = Rounds::new();
        for round in round_mles.iter() {
            let (commit, data) = jagged_prover.commit_multilinears(round.clone()).ok().unwrap();
            challenger.observe(commit);
            // Round-trip the prover data through bincode before using it.
            // NOTE(review): presumably this exercises (de)serialization of the prover data.
            let data_bytes = bincode::serialize(&data).unwrap();
            let data = bincode::deserialize(&data_bytes).unwrap();
            prover_data.push(data);
            commitments.push(commit);
        }

        // Evaluate every MLE of every round at the shared evaluation point.
        let mut evaluation_claims = Rounds::new();
        for round in round_mles.iter() {
            let mut evals = Evaluations::default();
            for mle in round.iter() {
                let eval = mle.eval_at(&eval_point);
                evals.push(eval);
            }
            evaluation_claims.push(evals);
        }

        let proof = jagged_prover
            .prove_trusted_evaluations(
                eval_point.clone(),
                evaluation_claims.clone(),
                prover_data,
                &mut challenger,
            )
            .ok()
            .unwrap();

        (proof, commitments, evaluation_claims)
    }

    #[test]
    fn test_jagged_verifier() {
        setup_logger();

        // Two rounds of tables: entry i of a round has `row_counts[i]` rows and
        // `column_counts[i]` columns.
        let row_counts_rounds = vec![
            vec![
                1 << 13,
                1 << 8,
                1 << 11,
                1 << 7,
                1 << 16,
                1 << 14,
                1 << 20,
                1 << 7,
                1 << 9,
                1 << 11,
                1 << 8,
                1 << 7,
                1 << 14,
                1 << 10,
                1 << 14,
                1 << 8,
            ],
            vec![1 << 8],
        ];
        let column_counts_rounds = vec![
            vec![47, 41, 41, 58, 52, 109, 428, 50, 53, 93, 100, 83, 31, 68, 134, 80],
            vec![512],
        ];

        let num_rounds = row_counts_rounds.len();

        let log_stacking_height = 21;
        let max_log_row_count = 20;

        let row_counts = row_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();
        let column_counts = column_counts_rounds.into_iter().collect::<Rounds<Vec<usize>>>();

        assert!(row_counts.len() == column_counts.len());

        let mut rng = thread_rng();

        // Random MLEs of the requested shapes, each padded with zeros to `max_log_row_count`
        // variables (an all-zero MLE when a table has no rows).
        let round_mles = row_counts
            .iter()
            .zip(column_counts.iter())
            .map(|(row_counts, col_counts)| {
                row_counts
                    .iter()
                    .zip(col_counts.iter())
                    .map(|(num_rows, num_cols)| {
                        if *num_rows == 0 {
                            PaddedMle::zeros(*num_cols, max_log_row_count)
                        } else {
                            let mle = Mle::<F>::rand(&mut rng, *num_cols, num_rows.ilog(2));
                            PaddedMle::padded_with_zeros(Arc::new(mle), max_log_row_count)
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Rounds<_>>();

        let jagged_verifier = JaggedPcsVerifier::<GC, JC>::new_from_basefold_params(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count as usize,
            num_rounds,
        );

        let eval_point = (0..max_log_row_count).map(|_| rng.gen::<EF>()).collect::<Point<_>>();

        let (proof, mut commitments, evaluation_claims) =
            generate_jagged_proof(&jagged_verifier, round_mles, eval_point.clone());

        // Native verification: replay the commitment observations on a fresh challenger.
        let mut challenger = jagged_verifier.challenger();

        for commitment in commitments.iter() {
            challenger.observe(*commitment);
        }

        // Flatten each round's per-table evaluations into a single `MleEval` per round.
        let evaluation_claims = evaluation_claims
            .iter()
            .map(|round| {
                round.iter().flat_map(|evals| evals.iter().cloned()).collect::<MleEval<_>>()
            })
            .collect::<Vec<_>>();

        jagged_verifier
            .verify_trusted_evaluations(
                &commitments,
                eval_point.clone(),
                &evaluation_claims,
                &proof,
                &mut challenger,
            )
            .unwrap();

        // Build the recursive verifier circuit. The inputs are read in this order:
        // commitments, eval point, evaluation claims, proof. The witness streams below must be
        // written in the same order.
        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("jagged - read input");
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
        let commitments_var = commitments.read(&mut builder);
        let eval_point_var = eval_point.read(&mut builder);
        let evaluation_claims_var = evaluation_claims.read(&mut builder);
        let proof_var = proof.read(&mut builder);
        builder.cycle_tracker_v2_exit();
        builder.cycle_tracker_v2_enter("jagged - observe commitments");
        for commitment_var in commitments_var.iter() {
            challenger_variable.observe_slice(&mut builder, *commitment_var);
        }
        builder.cycle_tracker_v2_exit();
        let verifier = BasefoldVerifier::<SC>::new(FriConfig::default_fri_config(), num_rounds);
        let recursive_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let recursive_verifier =
            RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<SC, C> {
            stacked_pcs_verifier: recursive_verifier,
            max_log_row_count: max_log_row_count as usize,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<SP1GlobalContext>(PhantomData),
        };

        // The machine wrapper derives insertion points from the original (unpadded) column
        // counts of each round.
        let recursive_jagged_verifier = RecursiveMachineJaggedPcsVerifier::new(
            &recursive_jagged_verifier,
            vec![column_counts[0].clone(), column_counts[1].clone()],
        );

        builder.cycle_tracker_v2_enter("jagged-verifier");
        recursive_jagged_verifier.verify_trusted_evaluations(
            &mut builder,
            &commitments_var,
            eval_point_var,
            &evaluation_claims_var,
            &proof_var,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();

        let program = compiler.compile_inner(block).validate().unwrap();

        // Honest witness: the compiled program must run to completion.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program.clone()), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().unwrap();

        // Tampered witness: bump the first element of the first round's commitment. The same
        // program must now fail.
        let mut witness_stream = Vec::new();
        commitments.rounds[0][0] += F::one();
        Witnessable::<AsmConfig>::write(&commitments, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
        executor.witness_stream = witness_stream.into();
        executor.run().expect_err("invalid proof should not be verified");
    }
}