1use std::marker::PhantomData;
2
3use itertools::izip;
4use slop_algebra::AbstractField;
5use slop_jagged::{JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof};
6use slop_multilinear::{Mle, MleEval, Point};
7use slop_sumcheck::PartialSumcheckProof;
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10 circuit::CircuitV2Builder,
11 ir::{Builder, Ext, Felt, SymbolicExt},
12};
13
14use crate::{
15 basefold::{
16 stacked::{RecursiveStackedPcsProof, RecursiveStackedPcsVerifier},
17 RecursiveBasefoldProof, RecursiveBasefoldVerifier,
18 },
19 challenger::FieldChallengerVariable,
20 sumcheck::{evaluate_mle_ext, verify_sumcheck},
21 CircuitConfig, SP1FieldConfigVariable,
22};
23
24use super::jagged_eval::{RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig};
25
/// Zero-sized, type-level tag parameterised by `C`, `SC` and `P`.
///
/// The only field is a [`PhantomData`] marker, so values of this type carry no data.
pub struct RecursivePcsImpl<C, SC, P> {
    /// Ties the otherwise unused type parameters to the struct.
    _marker: PhantomData<(C, SC, P)>,
}
29
30pub struct JaggedPcsProofVariable<Proof, Digest> {
31 pub params: JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
32 pub sumcheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
33 pub jagged_eval_proof: JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>,
34 pub pcs_proof: RecursiveStackedPcsProof<Proof, SP1Field, SP1ExtensionField>,
35 pub column_counts: Vec<Vec<usize>>,
36 pub row_counts: Vec<Vec<Felt<SP1Field>>>,
37 pub original_commitments: Vec<Digest>,
38 pub expected_eval: Ext<SP1Field, SP1ExtensionField>,
39}
40
41#[derive(Clone)]
42pub struct RecursiveJaggedPcsVerifier<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
43 pub stacked_pcs_verifier: RecursiveStackedPcsVerifier<RecursiveBasefoldVerifier<C, SC>>,
44 pub max_log_row_count: usize,
45 pub jagged_evaluator: RecursiveJaggedEvalSumcheckConfig<SC>,
46}
47
48impl<SC: SP1FieldConfigVariable<C>, C: CircuitConfig> RecursiveJaggedPcsVerifier<SC, C> {
49 #[allow(clippy::too_many_arguments)]
50 pub fn verify_trusted_evaluations(
51 &self,
52 builder: &mut Builder<C>,
53 commitments: &[SC::DigestVariable],
54 point: Point<Ext<SP1Field, SP1ExtensionField>>,
55 evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
56 proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
57 insertion_points: &[usize],
58 challenger: &mut SC::FriChallengerVariable,
59 ) -> Vec<Felt<SP1Field>> {
60 let JaggedPcsProofVariable {
61 pcs_proof,
62 sumcheck_proof,
63 jagged_eval_proof,
64 params,
65 column_counts,
66 original_commitments,
67 expected_eval,
68 ..
69 } = proof;
70 let num_col_variables = (params.col_prefix_sums.len() - 1).next_power_of_two().ilog2();
71
72 let z_col =
73 (0..num_col_variables).map(|_| challenger.sample_ext(builder)).collect::<Point<_>>();
74
75 let z_row = point;
76
77 let mut column_claims = evaluation_claims.iter().flatten().copied().collect::<Vec<_>>();
79
80 let added_columns: Vec<usize> =
81 column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
82 let zero_ext: Ext<SP1Field, SP1ExtensionField> =
87 builder.constant(SP1ExtensionField::zero());
88 for (insertion_point, num_added_columns) in
89 insertion_points.iter().rev().zip(added_columns.iter().rev())
90 {
91 for _ in 0..*num_added_columns {
92 column_claims.insert(*insertion_point, zero_ext);
93 }
94 }
95
96 for (round_column_counts, round_row_counts, modified_commitment, original_commitment) in izip!(
97 column_counts.iter(),
98 proof.row_counts.iter(),
99 commitments.iter(),
100 original_commitments.iter()
101 ) {
102 let mut felts_vec: Vec<Felt<_>> =
103 vec![builder.eval(SP1Field::from_canonical_usize(round_column_counts.len()))];
104 for &count in round_row_counts {
105 felts_vec.push(builder.eval(count));
106 }
107
108 for &count in round_column_counts {
109 felts_vec.push(builder.eval(SP1Field::from_canonical_usize(count)));
110 }
111 let hash = SC::hash(builder, &felts_vec);
112 let expected_commitment = SC::compress(builder, [*original_commitment, hash]);
113
114 SC::assert_digest_eq(builder, expected_commitment, *modified_commitment);
115 }
116
117 column_claims.resize(column_claims.len().next_power_of_two(), zero_ext);
119
120 let column_mle = Mle::from(column_claims);
121 let sumcheck_claim: Ext<SP1Field, SP1ExtensionField> =
122 evaluate_mle_ext(builder, column_mle, z_col.clone())[0];
123
124 builder.assert_ext_eq(sumcheck_claim, sumcheck_proof.claimed_sum);
125
126 builder.cycle_tracker_v2_enter("jagged - verify sumcheck");
127 verify_sumcheck::<C, SC>(builder, challenger, sumcheck_proof);
128 builder.cycle_tracker_v2_exit();
129
130 builder.cycle_tracker_v2_enter("jagged - jagged-eval");
131 let (jagged_eval, prefix_sum_felts) = self.jagged_evaluator.jagged_evaluation(
132 builder,
133 params,
134 z_row,
135 z_col,
136 sumcheck_proof.point_and_eval.0.clone(),
137 jagged_eval_proof,
138 challenger,
139 );
140 builder.cycle_tracker_v2_exit();
141
142 builder.assert_ext_eq(jagged_eval * *expected_eval, sumcheck_proof.point_and_eval.1);
144
145 let evaluation_point = sumcheck_proof.point_and_eval.0.clone();
147 self.stacked_pcs_verifier.verify_untrusted_evaluation(
148 builder,
149 original_commitments,
150 &evaluation_point,
151 pcs_proof,
152 SymbolicExt::from(*expected_eval),
153 challenger,
154 );
155 prefix_sum_felts
156 }
157}
158
159pub struct RecursiveMachineJaggedPcsVerifier<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig> {
160 pub jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
161 pub column_counts_by_round: Vec<Vec<usize>>,
162}
163
164impl<'a, SC: SP1FieldConfigVariable<C>, C: CircuitConfig>
165 RecursiveMachineJaggedPcsVerifier<'a, SC, C>
166{
167 pub fn new(
168 jagged_pcs_verifier: &'a RecursiveJaggedPcsVerifier<SC, C>,
169 column_counts_by_round: Vec<Vec<usize>>,
170 ) -> Self {
171 Self { jagged_pcs_verifier, column_counts_by_round }
172 }
173
174 pub fn verify_trusted_evaluations(
175 &self,
176 builder: &mut Builder<C>,
177 commitments: &[SC::DigestVariable],
178 point: Point<Ext<SP1Field, SP1ExtensionField>>,
179 evaluation_claims: &[MleEval<Ext<SP1Field, SP1ExtensionField>>],
180 proof: &JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
181 challenger: &mut SC::FriChallengerVariable,
182 ) -> Vec<Felt<SP1Field>> {
183 let insertion_points = self
184 .column_counts_by_round
185 .iter()
186 .scan(0, |state, y| {
187 *state += y.iter().sum::<usize>();
188 Some(*state)
189 })
190 .collect::<Vec<_>>();
191
192 self.jagged_pcs_verifier.verify_trusted_evaluations(
193 builder,
194 commitments,
195 point,
196 evaluation_claims,
197 proof,
198 &insertion_points,
199 challenger,
200 )
201 }
202}
203
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use rand::{thread_rng, Rng};
    use slop_algebra::AbstractField;
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_challenger::{CanObserve, IopCtx};
    use slop_commit::Rounds;
    use slop_jagged::{JaggedPcsProof, JaggedPcsVerifier, JaggedProver};
    use slop_multilinear::{Evaluations, Mle, MleEval, PaddedMle, Point};
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::{
        inner_perm, prover::SP1InnerPcsProver, SP1InnerPcs, SP1PcsProof, SP1PcsProofInner,
    };
    use sp1_primitives::{SP1DiffusionMatrix, SP1ExtensionField, SP1Field, SP1GlobalContext};
    use sp1_recursion_compiler::circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder};
    use sp1_recursion_executor::Executor;

    use crate::{
        basefold::{
            stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs,
            RecursiveBasefoldVerifier,
        },
        challenger::{CanObserveVariable, DuplexChallengerVariable},
        jagged::{
            jagged_eval::RecursiveJaggedEvalSumcheckConfig,
            verifier::{RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier},
        },
        witness::Witnessable,
    };

    type SC = SP1GlobalContext;
    type JC = SP1InnerPcs;
    type GC = SP1GlobalContext;
    type F = SP1Field;
    type EF = SP1ExtensionField;
    type C = AsmConfig;
    type Prover = JaggedProver<SP1GlobalContext, SP1PcsProofInner, SP1InnerPcsProver>;

    /// Commits to every round of `round_mles`, evaluates each MLE at `eval_point`, and
    /// produces a jagged opening proof for those evaluations.
    ///
    /// Returns the proof, the per-round commitments, and the per-round evaluation claims.
    #[allow(clippy::type_complexity)]
    fn generate_jagged_proof(
        jagged_verifier: &JaggedPcsVerifier<GC, JC>,
        round_mles: Rounds<Vec<PaddedMle<F>>>,
        eval_point: Point<EF>,
    ) -> (
        JaggedPcsProof<GC, SP1PcsProof<GC>>,
        Rounds<<GC as IopCtx>::Digest>,
        Rounds<Evaluations<EF>>,
    ) {
        let prover = Prover::from_verifier(jagged_verifier);
        let mut challenger = jagged_verifier.challenger();

        // Commit round by round, feeding each commitment to the challenger.
        let mut commitments = Rounds::new();
        let mut prover_data = Rounds::new();
        for round in round_mles.iter() {
            let (commitment, round_data) =
                prover.commit_multilinears(round.clone()).ok().unwrap();
            challenger.observe(commitment);
            // Round-trip the prover data through bincode before using it.
            let encoded = bincode::serialize(&round_data).unwrap();
            let round_data = bincode::deserialize(&encoded).unwrap();
            prover_data.push(round_data);
            commitments.push(commitment);
        }

        // Evaluate every MLE of every round at the opening point.
        let mut evaluation_claims = Rounds::new();
        for round in round_mles.iter() {
            let mut round_evals = Evaluations::default();
            for mle in round.iter() {
                round_evals.push(mle.eval_at(&eval_point));
            }
            evaluation_claims.push(round_evals);
        }

        let proof = prover
            .prove_trusted_evaluations(
                eval_point.clone(),
                evaluation_claims.clone(),
                prover_data,
                &mut challenger,
            )
            .ok()
            .unwrap();

        (proof, commitments, evaluation_claims)
    }

    /// Proves a two-round jagged opening, checks it with the native verifier, then runs the
    /// recursive verifier circuit on the honest witness (must succeed) and on a witness with
    /// a perturbed commitment (must fail).
    #[test]
    fn test_jagged_verifier() {
        setup_logger();

        let row_counts_rounds = vec![
            vec![
                1 << 13,
                1 << 8,
                1 << 11,
                1 << 7,
                1 << 16,
                1 << 14,
                1 << 20,
                1 << 7,
                1 << 9,
                1 << 11,
                1 << 8,
                1 << 7,
                1 << 14,
                1 << 10,
                1 << 14,
                1 << 8,
            ],
            vec![1 << 8],
        ];
        let column_counts_rounds = vec![
            vec![47, 41, 41, 58, 52, 109, 428, 50, 53, 93, 100, 83, 31, 68, 134, 80],
            vec![512],
        ];

        let num_rounds = row_counts_rounds.len();

        let log_stacking_height = 21;
        let max_log_row_count = 20;

        let row_counts: Rounds<Vec<usize>> = row_counts_rounds.into_iter().collect();
        let column_counts: Rounds<Vec<usize>> = column_counts_rounds.into_iter().collect();

        assert_eq!(row_counts.len(), column_counts.len());

        let mut rng = thread_rng();

        // One random (zero-padded) MLE per table; tables without rows become all-zero MLEs.
        let round_mles = row_counts
            .iter()
            .zip(column_counts.iter())
            .map(|(round_rows, round_cols)| {
                round_rows
                    .iter()
                    .zip(round_cols.iter())
                    .map(|(&num_rows, &num_cols)| match num_rows {
                        0 => PaddedMle::zeros(num_cols, max_log_row_count),
                        _ => {
                            let mle = Mle::<F>::rand(&mut rng, num_cols, num_rows.ilog2());
                            PaddedMle::padded_with_zeros(Arc::new(mle), max_log_row_count)
                        }
                    })
                    .collect::<Vec<_>>()
            })
            .collect::<Rounds<_>>();

        let jagged_verifier = JaggedPcsVerifier::<GC, JC>::new_from_basefold_params(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count as usize,
            num_rounds,
        );

        let eval_point = (0..max_log_row_count).map(|_| rng.gen::<EF>()).collect::<Point<_>>();

        let (proof, mut commitments, evaluation_claims) =
            generate_jagged_proof(&jagged_verifier, round_mles, eval_point.clone());

        // Replay the transcript for the native verifier.
        let mut challenger = jagged_verifier.challenger();
        for &commitment in commitments.iter() {
            challenger.observe(commitment);
        }

        // Flatten each round's evaluations into a single `MleEval`.
        let evaluation_claims = evaluation_claims
            .iter()
            .map(|round_evals| {
                round_evals
                    .iter()
                    .flat_map(|table_evals| table_evals.iter().cloned())
                    .collect::<MleEval<_>>()
            })
            .collect::<Vec<_>>();

        // Sanity check: the native verifier accepts the proof.
        jagged_verifier
            .verify_trusted_evaluations(
                &commitments,
                eval_point.clone(),
                &evaluation_claims,
                &proof,
                &mut challenger,
            )
            .unwrap();

        // Build the recursive verifier circuit.
        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("jagged - read input");
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
        let commitments_var = commitments.read(&mut builder);
        let eval_point_var = eval_point.read(&mut builder);
        let evaluation_claims_var = evaluation_claims.read(&mut builder);
        let proof_var = proof.read(&mut builder);
        builder.cycle_tracker_v2_exit();

        builder.cycle_tracker_v2_enter("jagged - observe commitments");
        for commitment_var in commitments_var.iter() {
            challenger_variable.observe_slice(&mut builder, *commitment_var);
        }
        builder.cycle_tracker_v2_exit();

        let native_basefold = BasefoldVerifier::<SC>::new(FriConfig::default_fri_config(), num_rounds);
        let basefold_verifier = RecursiveBasefoldVerifier::<C, SC> {
            fri_config: native_basefold.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
        };
        let stacked_verifier =
            RecursiveStackedPcsVerifier::new(basefold_verifier, log_stacking_height);

        let jagged_pcs_verifier = RecursiveJaggedPcsVerifier::<SC, C> {
            stacked_pcs_verifier: stacked_verifier,
            max_log_row_count: max_log_row_count as usize,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<SP1GlobalContext>(PhantomData),
        };

        let machine_verifier = RecursiveMachineJaggedPcsVerifier::new(
            &jagged_pcs_verifier,
            vec![column_counts[0].clone(), column_counts[1].clone()],
        );

        builder.cycle_tracker_v2_enter("jagged-verifier");
        machine_verifier.verify_trusted_evaluations(
            &mut builder,
            &commitments_var,
            eval_point_var,
            &evaluation_claims_var,
            &proof_var,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = compiler.compile_inner(block).validate().unwrap();

        // Writes the witness in the same order the circuit reads it, then runs the program.
        let run_with = |commitments: &Rounds<<GC as IopCtx>::Digest>| {
            let mut witness_stream = Vec::new();
            Witnessable::<AsmConfig>::write(commitments, &mut witness_stream);
            Witnessable::<AsmConfig>::write(&eval_point, &mut witness_stream);
            Witnessable::<AsmConfig>::write(&evaluation_claims, &mut witness_stream);
            Witnessable::<AsmConfig>::write(&proof, &mut witness_stream);
            let mut executor = Executor::<F, EF, SP1DiffusionMatrix>::new(
                Arc::new(program.clone()),
                inner_perm(),
            );
            executor.witness_stream = witness_stream.into();
            executor.run()
        };

        // The honest witness must be accepted.
        run_with(&commitments).unwrap();

        // Perturbing one element of the first commitment must make the execution fail.
        commitments.rounds[0][0] += F::one();
        run_with(&commitments).expect_err("invalid proof should not be verified");
    }
}