Skip to main content

sp1_recursion_circuit/
shard.rs

1use crate::{
2    basefold::RecursiveBasefoldProof,
3    challenger::CanObserveVariable,
4    jagged::{
5        JaggedPcsProofVariable, RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier,
6    },
7    logup_gkr::RecursiveLogUpGkrVerifier,
8    zerocheck::RecursiveVerifierConstraintFolder,
9    CircuitConfig, SP1FieldConfigVariable,
10};
11use slop_air::Air;
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_commit::Rounds;
15use slop_multilinear::{Evaluations, MleEval};
16use slop_sumcheck::PartialSumcheckProof;
17
18use sp1_hypercube::{
19    air::MachineAir, septic_digest::SepticDigest, GenericVerifierPublicValuesConstraintFolder,
20    LogupGkrProof, Machine, ShardOpenedValues, UntrustedConfig,
21};
22use sp1_primitives::{SP1ExtensionField, SP1Field};
23use sp1_recursion_compiler::{
24    circuit::CircuitV2Builder,
25    ir::{Builder, Felt, SymbolicExt},
26    prelude::{Ext, SymbolicFelt},
27};
28use sp1_recursion_executor::{DIGEST_SIZE, NUM_BITS};
29use std::collections::{BTreeMap, BTreeSet};
30
31#[allow(clippy::type_complexity)]
32pub struct ShardProofVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C> + Send + Sync> {
33    /// The commitments to main traces.
34    pub main_commitment: SC::DigestVariable,
35    /// The values of the traces at the final random point.
36    pub opened_values: ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
37    /// The zerocheck IOP proof.
38    pub zerocheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
39    /// The public values
40    pub public_values: Vec<Felt<SP1Field>>,
41    // TODO: The `LogUp+GKR` IOP proofs.
42    pub logup_gkr_proof: LogupGkrProof<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
43    /// The evaluation proof.
44    pub evaluation_proof: JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
45}
46
47pub struct MachineVerifyingKeyVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C>> {
48    pub pc_start: [Felt<SP1Field>; 3],
49    /// The starting global digest of the program, after incorporating the initial memory.
50    pub initial_global_cumulative_sum: SepticDigest<Felt<SP1Field>>,
51    /// The preprocessed commitments.
52    pub preprocessed_commit: SC::DigestVariable,
53    /// Metadata on configuration regarding untrusted programs.
54    pub untrusted_config: UntrustedConfig<Felt<SP1Field>>,
55}
56impl<C, SC> MachineVerifyingKeyVariable<C, SC>
57where
58    C: CircuitConfig,
59    SC: SP1FieldConfigVariable<C>,
60{
61    /// Hash the verifying key + prep domains into a single digest.
62    /// poseidon2(commit[0..8] || pc_start || initial_global_cumulative_sum ||
63    /// height || name)
64    pub fn hash(&self, builder: &mut Builder<C>) -> SC::DigestVariable
65    where
66        SC::DigestVariable: IntoIterator<Item = Felt<SP1Field>>,
67    {
68        #[cfg(not(feature = "mprotect"))]
69        let num_inputs = DIGEST_SIZE + 3 + 14 + 1;
70        #[cfg(feature = "mprotect")]
71        let num_inputs = DIGEST_SIZE + 3 + 14 + 1 + 1 + 9 + 6;
72        let mut inputs = Vec::with_capacity(num_inputs);
73        inputs.extend(self.preprocessed_commit);
74        inputs.extend(self.pc_start);
75        inputs.extend(self.initial_global_cumulative_sum.0.x.0);
76        inputs.extend(self.initial_global_cumulative_sum.0.y.0);
77        inputs.push(self.untrusted_config.enable_untrusted_programs);
78        #[cfg(feature = "mprotect")]
79        {
80            inputs.push(self.untrusted_config.enable_trap_handler);
81            inputs.extend(self.untrusted_config.trap_context.as_flattened());
82            inputs.extend(self.untrusted_config.untrusted_memory.as_flattened());
83        }
84
85        SC::hash(builder, &inputs)
86    }
87}
88
89/// A verifier for shard proofs.
90pub struct RecursiveShardVerifier<
91    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
92    A: MachineAir<SP1Field>,
93    C: CircuitConfig,
94> {
95    /// The machine.
96    pub machine: Machine<SP1Field, A>,
97    /// The jagged pcs verifier.
98    pub pcs_verifier: RecursiveJaggedPcsVerifier<GC, C>,
99    pub _phantom: std::marker::PhantomData<(GC, C, A)>,
100}
101
102impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
103where
104    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
105    A: MachineAir<SP1Field>,
106    C: CircuitConfig,
107{
108    pub fn verify_shard(
109        &self,
110        builder: &mut Builder<C>,
111        vk: &MachineVerifyingKeyVariable<C, GC>,
112        proof: &ShardProofVariable<C, GC>,
113        challenger: &mut GC::FriChallengerVariable,
114    ) where
115        A: for<'b> Air<RecursiveVerifierConstraintFolder<'b>>,
116    {
117        let ShardProofVariable {
118            main_commitment,
119            opened_values,
120            evaluation_proof,
121            zerocheck_proof,
122            public_values,
123            logup_gkr_proof,
124        } = proof;
125
126        // Convert height bits to felts.
127        let heights = opened_values
128            .chips
129            .iter()
130            .map(|(name, x)| (name.clone(), x.degree.clone()))
131            .collect::<BTreeMap<_, _>>();
132        let mut height_felts_map: BTreeMap<String, Felt<SP1Field>> = BTreeMap::new();
133        let two = SymbolicFelt::from_canonical_u32(2);
134        for (name, height) in &heights {
135            let mut acc = SymbolicFelt::zero();
136            // Assert max height to avoid overflow during prefix-sum-checks.
137            assert!(height.len() == self.pcs_verifier.max_log_row_count + 1);
138            height.iter().for_each(|x| {
139                acc = *x + two * acc;
140            });
141            height_felts_map.insert(name.clone(), builder.eval(acc));
142        }
143
144        // Observe the public values.
145        challenger.observe_slice(builder, public_values.to_vec());
146
147        for value in public_values[self.machine.num_pv_elts()..].iter() {
148            builder.assert_felt_eq(value, GC::F::zero());
149        }
150
151        // Observe the main commitment.
152        challenger.observe(builder, *main_commitment);
153        let num_chips: Felt<GC::F> = builder.eval(GC::F::from_canonical_usize(heights.len()));
154        // Observe the number of chips.
155        challenger.observe(builder, num_chips);
156
157        for (name, height) in height_felts_map.iter() {
158            challenger.observe(builder, *height);
159            let mut inputs: Vec<Felt<GC::F>> = vec![];
160            inputs.push(builder.eval(GC::F::from_canonical_usize(name.len())));
161            for byte in name.as_bytes() {
162                inputs.push(builder.eval(GC::F::from_canonical_u8(*byte)));
163            }
164            challenger.observe_slice(builder, inputs);
165        }
166
167        let shard_chips = self
168            .machine
169            .chips()
170            .iter()
171            .filter(|chip| heights.contains_key(chip.name()))
172            .cloned()
173            .collect::<BTreeSet<_>>();
174
175        let degrees = opened_values.chips.values().map(|x| x.degree.clone()).collect::<Vec<_>>();
176
177        let max_log_row_count = self.pcs_verifier.max_log_row_count;
178
179        // Verify the `LogUp` GKR proof.
180        builder.cycle_tracker_v2_enter("verify-logup-gkr");
181        RecursiveLogUpGkrVerifier::<C, GC, A>::verify_logup_gkr(
182            builder,
183            &shard_chips,
184            &degrees,
185            max_log_row_count,
186            logup_gkr_proof,
187            public_values,
188            challenger,
189        );
190        builder.cycle_tracker_v2_exit();
191
192        // Verify the zerocheck proof.
193        builder.cycle_tracker_v2_enter("verify-zerocheck");
194        self.verify_zerocheck(
195            builder,
196            &shard_chips,
197            opened_values,
198            &logup_gkr_proof.logup_evaluations,
199            zerocheck_proof,
200            public_values,
201            challenger,
202        );
203        builder.cycle_tracker_v2_exit();
204
205        // Verify the opening proof.
206        let (preprocessed_openings_for_proof, main_openings_for_proof): (Vec<_>, Vec<_>) = proof
207            .opened_values
208            .chips
209            .values()
210            .map(|opening| (opening.preprocessed.clone(), opening.main.clone()))
211            .unzip();
212
213        let preprocessed_openings = preprocessed_openings_for_proof
214            .iter()
215            .map(|x| x.local.iter().as_slice())
216            .collect::<Vec<_>>();
217
218        let main_openings = main_openings_for_proof
219            .iter()
220            .map(|x| x.local.iter().copied().collect::<MleEval<_>>())
221            .collect::<Evaluations<_>>();
222
223        let filtered_preprocessed_openings = preprocessed_openings
224            .clone()
225            .into_iter()
226            .filter(|x| !x.is_empty())
227            .map(|x| x.iter().copied().collect::<MleEval<_>>())
228            .collect::<Evaluations<_>>();
229
230        let preprocessed_column_count = filtered_preprocessed_openings
231            .iter()
232            .map(|table_openings| table_openings.len())
233            .collect::<Vec<_>>();
234
235        let added_columns: Vec<usize> =
236            proof.evaluation_proof.column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
237
238        let unfiltered_preprocessed_column_count = preprocessed_openings
239            .iter()
240            .map(|table_openings| table_openings.len())
241            .chain(std::iter::once(added_columns[0] - 1))
242            .collect::<Vec<_>>();
243
244        let main_column_count =
245            main_openings.iter().map(|table_openings| table_openings.len()).collect::<Vec<_>>();
246
247        let unfiltered_main_column_count = main_openings
248            .iter()
249            .map(|table_openings| table_openings.len())
250            .chain(std::iter::once(added_columns[1] - 1))
251            .collect::<Vec<_>>();
252
253        let (commitments, column_counts, unfiltered_column_counts, openings) = (
254            vec![vk.preprocessed_commit, *main_commitment],
255            vec![preprocessed_column_count, main_column_count.clone()],
256            vec![unfiltered_preprocessed_column_count, unfiltered_main_column_count],
257            Rounds { rounds: vec![filtered_preprocessed_openings, main_openings] },
258        );
259
260        let machine_jagged_verifier =
261            RecursiveMachineJaggedPcsVerifier::new(&self.pcs_verifier, column_counts.clone());
262
263        let openings = openings
264            .into_iter()
265            .map(|round| {
266                round
267                    .into_iter()
268                    .flat_map(std::iter::IntoIterator::into_iter)
269                    .collect::<MleEval<_>>()
270            })
271            .collect::<Vec<_>>();
272
273        builder.cycle_tracker_v2_enter("jagged-verifier");
274        let prefix_sum_felts = machine_jagged_verifier.verify_trusted_evaluations(
275            builder,
276            &commitments,
277            zerocheck_proof.point_and_eval.0.clone(),
278            &openings,
279            evaluation_proof,
280            challenger,
281        );
282        builder.cycle_tracker_v2_exit();
283
284        let row_count_felt: Felt<_> = builder
285            .constant(SP1Field::from_canonical_u32(1 << self.pcs_verifier.max_log_row_count));
286
287        let params: Vec<Vec<Felt<SP1Field>>> = unfiltered_column_counts
288            .iter()
289            .map(|round| {
290                round
291                    .iter()
292                    .copied()
293                    .zip(height_felts_map.values().copied().chain(std::iter::once(row_count_felt)))
294                    .flat_map(|(column_count, height)| {
295                        std::iter::repeat_n(height, column_count).collect::<Vec<_>>()
296                    })
297                    .collect::<Vec<_>>()
298            })
299            .collect();
300
301        let preprocessed_count = params[0].len();
302        let params = params.into_iter().flatten().collect::<Vec<_>>();
303
304        builder.cycle_tracker_v2_enter("jagged - prefix-sum-checks");
305        let mut param_index = 0;
306        // The prefix_sum_felts coming from the C::prefix_sum_checks call excludes what is the last
307        // element, namely the total area, in the Rust verifier. We add that check in manually
308        // below. That is why the Rust verifier `skip_indices` has two elements, while this
309        // one has one.
310        let skip_indices = [preprocessed_count];
311
312        prefix_sum_felts
313            .iter()
314            .zip(prefix_sum_felts.iter().skip(1))
315            .enumerate()
316            .filter(|(i, _)| !skip_indices.contains(i))
317            .for_each(|(_, (x, y))| {
318                let sum = *x + params[param_index];
319                builder.assert_felt_eq(sum, *y);
320                param_index += 1;
321            });
322
323        builder.assert_felt_eq(prefix_sum_felts[0], SP1Field::zero());
324
325        // Check that the preprocessed prefix sum is the correct multiple of `stacking_height`.
326        builder.assert_felt_eq(
327            prefix_sum_felts[skip_indices[0] + 1],
328            SP1Field::from_canonical_usize(
329                (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
330                    * evaluation_proof.pcs_proof.batch_evaluations.rounds[0].num_polynomials(),
331            ),
332        );
333
334        let preprocessed_padding_col_height =
335            builder.eval(prefix_sum_felts[skip_indices[0] + 1] - prefix_sum_felts[skip_indices[0]]);
336        let preprocessed_padding_col_bit_decomp = C::num2bits(
337            builder,
338            preprocessed_padding_col_height,
339            self.pcs_verifier.max_log_row_count + 1,
340        );
341
342        // We want to constrain the padding column to be in the range [0, 2^{max_log_row_count}].
343        // The above constraints ensure that the padding column is in the range [0,
344        // 2^{max_log_row_count+1}). The following constraints exclude the range
345        // (2^{max_log_row_count}, 2^{max_log_row_count+1}), namely by ensuring that if the
346        // the `max_log_row_count`-th bit is 1, then the less significant bits must be zero.
347        //
348        // NOTE: Strictly speaking, this is not necessary, since the jagged polynomial will
349        // force a zero evaluation in case any column height is greater than
350        // `2^{max_log_row_count}`, but we add this constraint for extra security, since it
351        // does not have a significant performance impact.
352        let max_bit = preprocessed_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
353        let max_bit = C::bits2num(builder, vec![max_bit]);
354        let zero: Felt<_> = builder.constant(SP1Field::zero());
355        for bit in
356            preprocessed_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count)
357        {
358            let bit_felt = C::bits2num(builder, vec![*bit]);
359            builder.assert_felt_eq(max_bit * bit_felt, zero);
360        }
361        let num_cols = prefix_sum_felts.len();
362
363        // Repeat the process above for the main trace padding column.
364        let main_padding_col_height =
365            builder.eval(prefix_sum_felts[num_cols - 1] - prefix_sum_felts[num_cols - 2]);
366
367        let main_padding_col_bit_decomp = C::num2bits(builder, main_padding_col_height, NUM_BITS);
368
369        let max_bit = main_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
370        let max_bit = C::bits2num(builder, vec![max_bit]);
371        for bit in main_padding_col_bit_decomp.iter().skip(self.pcs_verifier.max_log_row_count + 1)
372        {
373            C::assert_bit_zero(builder, *bit);
374        }
375        for bit in main_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count) {
376            let bit_felt = C::bits2num(builder, vec![*bit]);
377            builder.assert_felt_eq(max_bit * bit_felt, zero);
378        }
379
380        // Compute the total area from the shape of the stacked PCS proof.
381        let total_area_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(
382            (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
383                * proof
384                    .evaluation_proof
385                    .pcs_proof
386                    .batch_evaluations
387                    .iter()
388                    .map(|evaluations| evaluations.num_polynomials())
389                    .sum::<usize>(),
390        ));
391
392        // Convert the final prefix sum to a symbolic felt.
393        let mut acc = SymbolicFelt::zero();
394        // Assert max height to avoid overflow during prefix-sum-checks.
395        proof.evaluation_proof.params.col_prefix_sums.iter().last().unwrap().iter().for_each(|x| {
396            acc = *x + two * acc;
397        });
398
399        // Check equality between the two above-computed values.
400        builder.assert_felt_eq(acc, total_area_felt);
401
402        builder.cycle_tracker_v2_exit();
403    }
404}
405
406pub type RecursiveVerifierPublicValuesConstraintFolder<'a> =
407    GenericVerifierPublicValuesConstraintFolder<
408        'a,
409        SP1Field,
410        SP1ExtensionField,
411        Felt<SP1Field>,
412        Ext<SP1Field, SP1ExtensionField>,
413        SymbolicExt<SP1Field, SP1ExtensionField>,
414    >;
415
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use sp1_core_executor::{Program, SP1Context, SP1CoreOpts};
    use sp1_core_machine::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{prove_core, setup_logger},
    };
    use sp1_hypercube::{
        prover::{CpuShardProver, SP1InnerPcsProver, SimpleProver},
        MachineVerifier, SP1InnerPcs, ShardVerifier, NUM_SP1_COMMITMENTS,
    };
    use sp1_recursion_compiler::{
        circuit::{AsmCompiler, AsmConfig},
        config::InnerConfig,
    };
    use sp1_recursion_machine::test::run_recursion_test_machines;

    use crate::{
        basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
        challenger::DuplexChallengerVariable,
        dummy::dummy_shard_proof,
        jagged::RecursiveJaggedEvalSumcheckConfig,
        witness::Witnessable,
    };

    use super::*;

    use sp1_primitives::{SP1Field, SP1GlobalContext};
    type GC = SP1GlobalContext;
    type C = InnerConfig;
    type A = RiscvAir<SP1Field>;

    /// Proves the Fibonacci program natively, then checks its first shard proof inside the
    /// recursion circuit and runs the compiled program on the recursion machines.
    #[tokio::test]
    async fn test_verify_shard() {
        setup_logger();
        let log_stacking_height = 21;
        let max_log_row_count = 22;
        let machine = RiscvAir::machine();
        let verifier = ShardVerifier::from_basefold_parameters(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count,
            machine.clone(),
        );

        // Prove the Fibonacci test program to obtain a real vk and shard proof.
        let elf = test_artifacts::FIBONACCI_ELF;
        let program = Arc::new(Program::from(&elf).unwrap());
        let shard_prover =
            CpuShardProver::<SP1GlobalContext, SP1InnerPcs, SP1InnerPcsProver, _>::new(
                verifier.clone(),
            );
        let prover = SimpleProver::new(verifier.clone(), shard_prover);

        let (pk, vk) = prover.setup(program.clone()).await;
        // NOTE(review): `into_inner` is `unsafe`; the invariant that makes this call sound is
        // not visible in this file — state it in a `// SAFETY:` comment.
        let pk = unsafe { pk.into_inner() };
        let (proof, _) = prove_core(
            &prover,
            pk,
            program,
            SP1Stdin::default(),
            SP1CoreOpts::default(),
            SP1Context::default(),
        )
        .await
        .unwrap();

        let mut builder = Builder::<C>::default();

        // Observe the vk into a fresh native challenger; its state seeds the circuit's
        // challenger below.
        let mut initial_challenger = verifier.jagged_pcs_verifier.challenger();
        vk.observe_into(&mut initial_challenger);

        // Sanity-check the proof natively before verifying it in the circuit.
        let machine_verifier = MachineVerifier::new(verifier);
        machine_verifier.verify(&vk, &proof).unwrap();

        // The circuit is built from a dummy proof with the same shape as the real shard proof;
        // the real proof is supplied later through the witness stream.
        let shard_proof = proof.shard_proofs[0].clone();
        let shape = machine_verifier.shape_from_proof(&shard_proof);

        let dummy_proof = dummy_shard_proof(
            shape.shard_chips,
            max_log_row_count,
            FriConfig::default_fri_config(),
            log_stacking_height as usize,
            &[
                shape.preprocessed_area >> log_stacking_height,
                shape.main_area >> log_stacking_height,
            ],
            &[shape.preprocessed_padding_cols, shape.main_padding_cols],
        );

        let vk_variable = vk.read(&mut builder);
        let shard_proof_variable = dummy_proof.read(&mut builder);

        let basefold_verifier =
            BasefoldVerifier::<GC>::new(FriConfig::default_fri_config(), NUM_SP1_COMMITMENTS);
        let recursive_basefold_verifier = crate::basefold::RecursiveBasefoldVerifier::<C, GC> {
            fri_config: basefold_verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
        };
        let recursive_stacked_verifier =
            RecursiveStackedPcsVerifier::new(recursive_basefold_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<GC, C> {
            stacked_pcs_verifier: recursive_stacked_verifier,
            max_log_row_count,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
        };

        let stark_verifier = RecursiveShardVerifier::<GC, A, C> {
            machine,
            pcs_verifier: recursive_jagged_verifier,
            _phantom: std::marker::PhantomData,
        };

        let mut challenger_variable =
            DuplexChallengerVariable::from_challenger(&mut builder, &initial_challenger);

        builder.cycle_tracker_v2_enter("verify-shard");
        stark_verifier.verify_shard(
            &mut builder,
            &vk_variable,
            &shard_proof_variable,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = compiler.compile_inner(block).validate().unwrap();

        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&vk, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&shard_proof, &mut witness_stream);

        run_recursion_test_machines(program.clone(), witness_stream).await;
    }
}