1use crate::{
2 basefold::RecursiveBasefoldProof,
3 challenger::CanObserveVariable,
4 jagged::{
5 JaggedPcsProofVariable, RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier,
6 },
7 logup_gkr::RecursiveLogUpGkrVerifier,
8 zerocheck::RecursiveVerifierConstraintFolder,
9 CircuitConfig, SP1FieldConfigVariable,
10};
11use slop_air::Air;
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_commit::Rounds;
15use slop_multilinear::{Evaluations, MleEval};
16use slop_sumcheck::PartialSumcheckProof;
17
18use sp1_hypercube::{
19 air::MachineAir, septic_digest::SepticDigest, GenericVerifierPublicValuesConstraintFolder,
20 LogupGkrProof, Machine, ShardOpenedValues, UntrustedConfig,
21};
22use sp1_primitives::{SP1ExtensionField, SP1Field};
23use sp1_recursion_compiler::{
24 circuit::CircuitV2Builder,
25 ir::{Builder, Felt, SymbolicExt},
26 prelude::{Ext, SymbolicFelt},
27};
28use sp1_recursion_executor::{DIGEST_SIZE, NUM_BITS};
29use std::collections::{BTreeMap, BTreeSet};
30
31#[allow(clippy::type_complexity)]
32pub struct ShardProofVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C> + Send + Sync> {
33 pub main_commitment: SC::DigestVariable,
35 pub opened_values: ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
37 pub zerocheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
39 pub public_values: Vec<Felt<SP1Field>>,
41 pub logup_gkr_proof: LogupGkrProof<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
43 pub evaluation_proof: JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
45}
46
47pub struct MachineVerifyingKeyVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C>> {
48 pub pc_start: [Felt<SP1Field>; 3],
49 pub initial_global_cumulative_sum: SepticDigest<Felt<SP1Field>>,
51 pub preprocessed_commit: SC::DigestVariable,
53 pub untrusted_config: UntrustedConfig<Felt<SP1Field>>,
55}
56impl<C, SC> MachineVerifyingKeyVariable<C, SC>
57where
58 C: CircuitConfig,
59 SC: SP1FieldConfigVariable<C>,
60{
61 pub fn hash(&self, builder: &mut Builder<C>) -> SC::DigestVariable
65 where
66 SC::DigestVariable: IntoIterator<Item = Felt<SP1Field>>,
67 {
68 #[cfg(not(feature = "mprotect"))]
69 let num_inputs = DIGEST_SIZE + 3 + 14 + 1;
70 #[cfg(feature = "mprotect")]
71 let num_inputs = DIGEST_SIZE + 3 + 14 + 1 + 1 + 9 + 6;
72 let mut inputs = Vec::with_capacity(num_inputs);
73 inputs.extend(self.preprocessed_commit);
74 inputs.extend(self.pc_start);
75 inputs.extend(self.initial_global_cumulative_sum.0.x.0);
76 inputs.extend(self.initial_global_cumulative_sum.0.y.0);
77 inputs.push(self.untrusted_config.enable_untrusted_programs);
78 #[cfg(feature = "mprotect")]
79 {
80 inputs.push(self.untrusted_config.enable_trap_handler);
81 inputs.extend(self.untrusted_config.trap_context.as_flattened());
82 inputs.extend(self.untrusted_config.untrusted_memory.as_flattened());
83 }
84
85 SC::hash(builder, &inputs)
86 }
87}
88
89pub struct RecursiveShardVerifier<
91 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
92 A: MachineAir<SP1Field>,
93 C: CircuitConfig,
94> {
95 pub machine: Machine<SP1Field, A>,
97 pub pcs_verifier: RecursiveJaggedPcsVerifier<GC, C>,
99 pub _phantom: std::marker::PhantomData<(GC, C, A)>,
100}
101
102impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
103where
104 GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
105 A: MachineAir<SP1Field>,
106 C: CircuitConfig,
107{
108 pub fn verify_shard(
109 &self,
110 builder: &mut Builder<C>,
111 vk: &MachineVerifyingKeyVariable<C, GC>,
112 proof: &ShardProofVariable<C, GC>,
113 challenger: &mut GC::FriChallengerVariable,
114 ) where
115 A: for<'b> Air<RecursiveVerifierConstraintFolder<'b>>,
116 {
117 let ShardProofVariable {
118 main_commitment,
119 opened_values,
120 evaluation_proof,
121 zerocheck_proof,
122 public_values,
123 logup_gkr_proof,
124 } = proof;
125
126 let heights = opened_values
128 .chips
129 .iter()
130 .map(|(name, x)| (name.clone(), x.degree.clone()))
131 .collect::<BTreeMap<_, _>>();
132 let mut height_felts_map: BTreeMap<String, Felt<SP1Field>> = BTreeMap::new();
133 let two = SymbolicFelt::from_canonical_u32(2);
134 for (name, height) in &heights {
135 let mut acc = SymbolicFelt::zero();
136 assert!(height.len() == self.pcs_verifier.max_log_row_count + 1);
138 height.iter().for_each(|x| {
139 acc = *x + two * acc;
140 });
141 height_felts_map.insert(name.clone(), builder.eval(acc));
142 }
143
144 challenger.observe_slice(builder, public_values.to_vec());
146
147 for value in public_values[self.machine.num_pv_elts()..].iter() {
148 builder.assert_felt_eq(value, GC::F::zero());
149 }
150
151 challenger.observe(builder, *main_commitment);
153 let num_chips: Felt<GC::F> = builder.eval(GC::F::from_canonical_usize(heights.len()));
154 challenger.observe(builder, num_chips);
156
157 for (name, height) in height_felts_map.iter() {
158 challenger.observe(builder, *height);
159 let mut inputs: Vec<Felt<GC::F>> = vec![];
160 inputs.push(builder.eval(GC::F::from_canonical_usize(name.len())));
161 for byte in name.as_bytes() {
162 inputs.push(builder.eval(GC::F::from_canonical_u8(*byte)));
163 }
164 challenger.observe_slice(builder, inputs);
165 }
166
167 let shard_chips = self
168 .machine
169 .chips()
170 .iter()
171 .filter(|chip| heights.contains_key(chip.name()))
172 .cloned()
173 .collect::<BTreeSet<_>>();
174
175 let degrees = opened_values.chips.values().map(|x| x.degree.clone()).collect::<Vec<_>>();
176
177 let max_log_row_count = self.pcs_verifier.max_log_row_count;
178
179 builder.cycle_tracker_v2_enter("verify-logup-gkr");
181 RecursiveLogUpGkrVerifier::<C, GC, A>::verify_logup_gkr(
182 builder,
183 &shard_chips,
184 °rees,
185 max_log_row_count,
186 logup_gkr_proof,
187 public_values,
188 challenger,
189 );
190 builder.cycle_tracker_v2_exit();
191
192 builder.cycle_tracker_v2_enter("verify-zerocheck");
194 self.verify_zerocheck(
195 builder,
196 &shard_chips,
197 opened_values,
198 &logup_gkr_proof.logup_evaluations,
199 zerocheck_proof,
200 public_values,
201 challenger,
202 );
203 builder.cycle_tracker_v2_exit();
204
205 let (preprocessed_openings_for_proof, main_openings_for_proof): (Vec<_>, Vec<_>) = proof
207 .opened_values
208 .chips
209 .values()
210 .map(|opening| (opening.preprocessed.clone(), opening.main.clone()))
211 .unzip();
212
213 let preprocessed_openings = preprocessed_openings_for_proof
214 .iter()
215 .map(|x| x.local.iter().as_slice())
216 .collect::<Vec<_>>();
217
218 let main_openings = main_openings_for_proof
219 .iter()
220 .map(|x| x.local.iter().copied().collect::<MleEval<_>>())
221 .collect::<Evaluations<_>>();
222
223 let filtered_preprocessed_openings = preprocessed_openings
224 .clone()
225 .into_iter()
226 .filter(|x| !x.is_empty())
227 .map(|x| x.iter().copied().collect::<MleEval<_>>())
228 .collect::<Evaluations<_>>();
229
230 let preprocessed_column_count = filtered_preprocessed_openings
231 .iter()
232 .map(|table_openings| table_openings.len())
233 .collect::<Vec<_>>();
234
235 let added_columns: Vec<usize> =
236 proof.evaluation_proof.column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
237
238 let unfiltered_preprocessed_column_count = preprocessed_openings
239 .iter()
240 .map(|table_openings| table_openings.len())
241 .chain(std::iter::once(added_columns[0] - 1))
242 .collect::<Vec<_>>();
243
244 let main_column_count =
245 main_openings.iter().map(|table_openings| table_openings.len()).collect::<Vec<_>>();
246
247 let unfiltered_main_column_count = main_openings
248 .iter()
249 .map(|table_openings| table_openings.len())
250 .chain(std::iter::once(added_columns[1] - 1))
251 .collect::<Vec<_>>();
252
253 let (commitments, column_counts, unfiltered_column_counts, openings) = (
254 vec![vk.preprocessed_commit, *main_commitment],
255 vec![preprocessed_column_count, main_column_count.clone()],
256 vec![unfiltered_preprocessed_column_count, unfiltered_main_column_count],
257 Rounds { rounds: vec![filtered_preprocessed_openings, main_openings] },
258 );
259
260 let machine_jagged_verifier =
261 RecursiveMachineJaggedPcsVerifier::new(&self.pcs_verifier, column_counts.clone());
262
263 let openings = openings
264 .into_iter()
265 .map(|round| {
266 round
267 .into_iter()
268 .flat_map(std::iter::IntoIterator::into_iter)
269 .collect::<MleEval<_>>()
270 })
271 .collect::<Vec<_>>();
272
273 builder.cycle_tracker_v2_enter("jagged-verifier");
274 let prefix_sum_felts = machine_jagged_verifier.verify_trusted_evaluations(
275 builder,
276 &commitments,
277 zerocheck_proof.point_and_eval.0.clone(),
278 &openings,
279 evaluation_proof,
280 challenger,
281 );
282 builder.cycle_tracker_v2_exit();
283
284 let row_count_felt: Felt<_> = builder
285 .constant(SP1Field::from_canonical_u32(1 << self.pcs_verifier.max_log_row_count));
286
287 let params: Vec<Vec<Felt<SP1Field>>> = unfiltered_column_counts
288 .iter()
289 .map(|round| {
290 round
291 .iter()
292 .copied()
293 .zip(height_felts_map.values().copied().chain(std::iter::once(row_count_felt)))
294 .flat_map(|(column_count, height)| {
295 std::iter::repeat_n(height, column_count).collect::<Vec<_>>()
296 })
297 .collect::<Vec<_>>()
298 })
299 .collect();
300
301 let preprocessed_count = params[0].len();
302 let params = params.into_iter().flatten().collect::<Vec<_>>();
303
304 builder.cycle_tracker_v2_enter("jagged - prefix-sum-checks");
305 let mut param_index = 0;
306 let skip_indices = [preprocessed_count];
311
312 prefix_sum_felts
313 .iter()
314 .zip(prefix_sum_felts.iter().skip(1))
315 .enumerate()
316 .filter(|(i, _)| !skip_indices.contains(i))
317 .for_each(|(_, (x, y))| {
318 let sum = *x + params[param_index];
319 builder.assert_felt_eq(sum, *y);
320 param_index += 1;
321 });
322
323 builder.assert_felt_eq(prefix_sum_felts[0], SP1Field::zero());
324
325 builder.assert_felt_eq(
327 prefix_sum_felts[skip_indices[0] + 1],
328 SP1Field::from_canonical_usize(
329 (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
330 * evaluation_proof.pcs_proof.batch_evaluations.rounds[0].num_polynomials(),
331 ),
332 );
333
334 let preprocessed_padding_col_height =
335 builder.eval(prefix_sum_felts[skip_indices[0] + 1] - prefix_sum_felts[skip_indices[0]]);
336 let preprocessed_padding_col_bit_decomp = C::num2bits(
337 builder,
338 preprocessed_padding_col_height,
339 self.pcs_verifier.max_log_row_count + 1,
340 );
341
342 let max_bit = preprocessed_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
353 let max_bit = C::bits2num(builder, vec![max_bit]);
354 let zero: Felt<_> = builder.constant(SP1Field::zero());
355 for bit in
356 preprocessed_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count)
357 {
358 let bit_felt = C::bits2num(builder, vec![*bit]);
359 builder.assert_felt_eq(max_bit * bit_felt, zero);
360 }
361 let num_cols = prefix_sum_felts.len();
362
363 let main_padding_col_height =
365 builder.eval(prefix_sum_felts[num_cols - 1] - prefix_sum_felts[num_cols - 2]);
366
367 let main_padding_col_bit_decomp = C::num2bits(builder, main_padding_col_height, NUM_BITS);
368
369 let max_bit = main_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
370 let max_bit = C::bits2num(builder, vec![max_bit]);
371 for bit in main_padding_col_bit_decomp.iter().skip(self.pcs_verifier.max_log_row_count + 1)
372 {
373 C::assert_bit_zero(builder, *bit);
374 }
375 for bit in main_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count) {
376 let bit_felt = C::bits2num(builder, vec![*bit]);
377 builder.assert_felt_eq(max_bit * bit_felt, zero);
378 }
379
380 let total_area_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(
382 (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
383 * proof
384 .evaluation_proof
385 .pcs_proof
386 .batch_evaluations
387 .iter()
388 .map(|evaluations| evaluations.num_polynomials())
389 .sum::<usize>(),
390 ));
391
392 let mut acc = SymbolicFelt::zero();
394 proof.evaluation_proof.params.col_prefix_sums.iter().last().unwrap().iter().for_each(|x| {
396 acc = *x + two * acc;
397 });
398
399 builder.assert_felt_eq(acc, total_area_felt);
401
402 builder.cycle_tracker_v2_exit();
403 }
404}
405
406pub type RecursiveVerifierPublicValuesConstraintFolder<'a> =
407 GenericVerifierPublicValuesConstraintFolder<
408 'a,
409 SP1Field,
410 SP1ExtensionField,
411 Felt<SP1Field>,
412 Ext<SP1Field, SP1ExtensionField>,
413 SymbolicExt<SP1Field, SP1ExtensionField>,
414 >;
415
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use sp1_core_executor::{Program, SP1Context, SP1CoreOpts};
    use sp1_core_machine::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{prove_core, setup_logger},
    };
    use sp1_hypercube::{
        prover::{CpuShardProver, SP1InnerPcsProver, SimpleProver},
        MachineVerifier, SP1InnerPcs, ShardVerifier, NUM_SP1_COMMITMENTS,
    };
    use sp1_recursion_compiler::{
        circuit::{AsmCompiler, AsmConfig},
        config::InnerConfig,
    };
    use sp1_recursion_machine::test::run_recursion_test_machines;

    use crate::{
        basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
        challenger::DuplexChallengerVariable,
        dummy::dummy_shard_proof,
        jagged::RecursiveJaggedEvalSumcheckConfig,
        witness::Witnessable,
    };

    use super::*;

    use sp1_primitives::{SP1Field, SP1GlobalContext};
    type GC = SP1GlobalContext;
    type C = InnerConfig;
    type A = RiscvAir<SP1Field>;

    /// End-to-end check of `verify_shard`.
    ///
    /// Proves the Fibonacci program natively, builds the recursive shard-verification circuit
    /// from a dummy proof of the same shape, and runs the compiled circuit on the real first
    /// shard proof.
    #[tokio::test]
    async fn test_verify_shard() {
        setup_logger();
        let log_stacking_height = 21;
        let max_log_row_count = 22;
        let machine = RiscvAir::machine();
        let shard_verifier = ShardVerifier::from_basefold_parameters(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count,
            machine.clone(),
        );

        // Native proving.
        let fibonacci_elf = test_artifacts::FIBONACCI_ELF;
        let program = Arc::new(Program::from(&fibonacci_elf).unwrap());
        let shard_prover =
            CpuShardProver::<SP1GlobalContext, SP1InnerPcs, SP1InnerPcsProver, _>::new(
                shard_verifier.clone(),
            );
        let prover = SimpleProver::new(shard_verifier.clone(), shard_prover);

        let (pk, vk) = prover.setup(program.clone()).await;
        // NOTE(review): the safety contract of `into_inner` is defined in sp1_hypercube;
        // confirm it holds here.
        let pk = unsafe { pk.into_inner() };
        let (proof, _) = prove_core(
            &prover,
            pk,
            program,
            SP1Stdin::default(),
            SP1CoreOpts::default(),
            SP1Context::default(),
        )
        .await
        .unwrap();

        let mut builder = Builder::<C>::default();

        // Native challenger state after absorbing the verifying key.
        let mut initial_challenger = shard_verifier.jagged_pcs_verifier.challenger();
        vk.observe_into(&mut initial_challenger);

        // Sanity check: the native verifier accepts the proof.
        let machine_verifier = MachineVerifier::new(shard_verifier);
        machine_verifier.verify(&vk, &proof).unwrap();

        let first_shard_proof = proof.shard_proofs[0].clone();
        let shape = machine_verifier.shape_from_proof(&first_shard_proof);

        // A dummy proof with the same shape fixes the circuit layout.
        let shaped_dummy_proof = dummy_shard_proof(
            shape.shard_chips,
            max_log_row_count,
            FriConfig::default_fri_config(),
            log_stacking_height as usize,
            &[
                shape.preprocessed_area >> log_stacking_height,
                shape.main_area >> log_stacking_height,
            ],
            &[shape.preprocessed_padding_cols, shape.main_padding_cols],
        );

        let vk_variable = vk.read(&mut builder);
        let shard_proof_variable = shaped_dummy_proof.read(&mut builder);

        // Recursive verifier stack: basefold -> stacked PCS -> jagged PCS -> shard.
        let basefold_verifier =
            BasefoldVerifier::<GC>::new(FriConfig::default_fri_config(), NUM_SP1_COMMITMENTS);
        let recursive_basefold_verifier = crate::basefold::RecursiveBasefoldVerifier::<C, GC> {
            fri_config: basefold_verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
        };
        let recursive_stacked_verifier =
            RecursiveStackedPcsVerifier::new(recursive_basefold_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<GC, C> {
            stacked_pcs_verifier: recursive_stacked_verifier,
            max_log_row_count,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
        };

        let recursive_shard_verifier = RecursiveShardVerifier::<GC, A, C> {
            machine,
            pcs_verifier: recursive_jagged_verifier,
            _phantom: std::marker::PhantomData,
        };

        let mut challenger_variable =
            DuplexChallengerVariable::from_challenger(&mut builder, &initial_challenger);

        builder.cycle_tracker_v2_enter("verify-shard");
        recursive_shard_verifier.verify_shard(
            &mut builder,
            &vk_variable,
            &shard_proof_variable,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        // Compile and execute the circuit on the real proof witness.
        let root_block = builder.into_root_block();
        let mut asm_compiler = AsmCompiler::default();
        let recursion_program = asm_compiler.compile_inner(root_block).validate().unwrap();

        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&vk, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&first_shard_proof, &mut witness_stream);

        run_recursion_test_machines(recursion_program.clone(), witness_stream).await;
    }
}