Skip to main content

sp1_recursion_circuit/
shard.rs

1use crate::{
2    basefold::RecursiveBasefoldProof,
3    challenger::CanObserveVariable,
4    jagged::{
5        JaggedPcsProofVariable, RecursiveJaggedPcsVerifier, RecursiveMachineJaggedPcsVerifier,
6    },
7    logup_gkr::RecursiveLogUpGkrVerifier,
8    zerocheck::RecursiveVerifierConstraintFolder,
9    CircuitConfig, SP1FieldConfigVariable,
10};
11use slop_air::Air;
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_commit::Rounds;
15use slop_multilinear::{Evaluations, MleEval};
16use slop_sumcheck::PartialSumcheckProof;
17
18use sp1_hypercube::{
19    air::MachineAir, septic_digest::SepticDigest, GenericVerifierPublicValuesConstraintFolder,
20    LogupGkrProof, Machine, ShardOpenedValues,
21};
22use sp1_primitives::{SP1ExtensionField, SP1Field};
23use sp1_recursion_compiler::{
24    circuit::CircuitV2Builder,
25    ir::{Builder, Felt, SymbolicExt},
26    prelude::{Ext, SymbolicFelt},
27};
28use sp1_recursion_executor::{DIGEST_SIZE, NUM_BITS};
29use std::collections::{BTreeMap, BTreeSet};
30
31#[allow(clippy::type_complexity)]
32pub struct ShardProofVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C> + Send + Sync> {
33    /// The commitments to main traces.
34    pub main_commitment: SC::DigestVariable,
35    /// The values of the traces at the final random point.
36    pub opened_values: ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
37    /// The zerocheck IOP proof.
38    pub zerocheck_proof: PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
39    /// The public values
40    pub public_values: Vec<Felt<SP1Field>>,
41    // TODO: The `LogUp+GKR` IOP proofs.
42    pub logup_gkr_proof: LogupGkrProof<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
43    /// The evaluation proof.
44    pub evaluation_proof: JaggedPcsProofVariable<RecursiveBasefoldProof<C, SC>, SC::DigestVariable>,
45}
46
47pub struct MachineVerifyingKeyVariable<C: CircuitConfig, SC: SP1FieldConfigVariable<C>> {
48    pub pc_start: [Felt<SP1Field>; 3],
49    /// The starting global digest of the program, after incorporating the initial memory.
50    pub initial_global_cumulative_sum: SepticDigest<Felt<SP1Field>>,
51    /// The preprocessed commitments.
52    pub preprocessed_commit: SC::DigestVariable,
53    /// Flag indicating if untrusted programs are allowed.
54    pub enable_untrusted_programs: Felt<SP1Field>,
55}
56impl<C, SC> MachineVerifyingKeyVariable<C, SC>
57where
58    C: CircuitConfig,
59    SC: SP1FieldConfigVariable<C>,
60{
61    /// Hash the verifying key + prep domains into a single digest.
62    /// poseidon2(commit[0..8] || pc_start || initial_global_cumulative_sum ||
63    /// height || name)
64    pub fn hash(&self, builder: &mut Builder<C>) -> SC::DigestVariable
65    where
66        SC::DigestVariable: IntoIterator<Item = Felt<SP1Field>>,
67    {
68        let num_inputs = DIGEST_SIZE + 3 + 14 + 1;
69        let mut inputs = Vec::with_capacity(num_inputs);
70        inputs.extend(self.preprocessed_commit);
71        inputs.extend(self.pc_start);
72        inputs.extend(self.initial_global_cumulative_sum.0.x.0);
73        inputs.extend(self.initial_global_cumulative_sum.0.y.0);
74        inputs.push(self.enable_untrusted_programs);
75
76        SC::hash(builder, &inputs)
77    }
78}
79
80/// A verifier for shard proofs.
81pub struct RecursiveShardVerifier<
82    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
83    A: MachineAir<SP1Field>,
84    C: CircuitConfig,
85> {
86    /// The machine.
87    pub machine: Machine<SP1Field, A>,
88    /// The jagged pcs verifier.
89    pub pcs_verifier: RecursiveJaggedPcsVerifier<GC, C>,
90    pub _phantom: std::marker::PhantomData<(GC, C, A)>,
91}
92
93impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
94where
95    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
96    A: MachineAir<SP1Field>,
97    C: CircuitConfig,
98{
99    pub fn verify_shard(
100        &self,
101        builder: &mut Builder<C>,
102        vk: &MachineVerifyingKeyVariable<C, GC>,
103        proof: &ShardProofVariable<C, GC>,
104        challenger: &mut GC::FriChallengerVariable,
105    ) where
106        A: for<'b> Air<RecursiveVerifierConstraintFolder<'b>>,
107    {
108        let ShardProofVariable {
109            main_commitment,
110            opened_values,
111            evaluation_proof,
112            zerocheck_proof,
113            public_values,
114            logup_gkr_proof,
115        } = proof;
116
117        // Convert height bits to felts.
118        let heights = opened_values
119            .chips
120            .iter()
121            .map(|(name, x)| (name.clone(), x.degree.clone()))
122            .collect::<BTreeMap<_, _>>();
123        let mut height_felts_map: BTreeMap<String, Felt<SP1Field>> = BTreeMap::new();
124        let two = SymbolicFelt::from_canonical_u32(2);
125        for (name, height) in &heights {
126            let mut acc = SymbolicFelt::zero();
127            // Assert max height to avoid overflow during prefix-sum-checks.
128            assert!(height.len() == self.pcs_verifier.max_log_row_count + 1);
129            height.iter().for_each(|x| {
130                acc = *x + two * acc;
131            });
132            height_felts_map.insert(name.clone(), builder.eval(acc));
133        }
134
135        // Observe the public values.
136        challenger.observe_slice(builder, public_values.to_vec());
137
138        for value in public_values[self.machine.num_pv_elts()..].iter() {
139            builder.assert_felt_eq(value, GC::F::zero());
140        }
141
142        // Observe the main commitment.
143        challenger.observe(builder, *main_commitment);
144        let num_chips: Felt<GC::F> = builder.eval(GC::F::from_canonical_usize(heights.len()));
145        // Observe the number of chips.
146        challenger.observe(builder, num_chips);
147
148        for (name, height) in height_felts_map.iter() {
149            challenger.observe(builder, *height);
150            let mut inputs: Vec<Felt<GC::F>> = vec![];
151            inputs.push(builder.eval(GC::F::from_canonical_usize(name.len())));
152            for byte in name.as_bytes() {
153                inputs.push(builder.eval(GC::F::from_canonical_u8(*byte)));
154            }
155            challenger.observe_slice(builder, inputs);
156        }
157
158        let shard_chips = self
159            .machine
160            .chips()
161            .iter()
162            .filter(|chip| heights.contains_key(chip.name()))
163            .cloned()
164            .collect::<BTreeSet<_>>();
165
166        let degrees = opened_values.chips.values().map(|x| x.degree.clone()).collect::<Vec<_>>();
167
168        let max_log_row_count = self.pcs_verifier.max_log_row_count;
169
170        // Verify the `LogUp` GKR proof.
171        builder.cycle_tracker_v2_enter("verify-logup-gkr");
172        RecursiveLogUpGkrVerifier::<C, GC, A>::verify_logup_gkr(
173            builder,
174            &shard_chips,
175            &degrees,
176            max_log_row_count,
177            logup_gkr_proof,
178            public_values,
179            challenger,
180        );
181        builder.cycle_tracker_v2_exit();
182
183        // Verify the zerocheck proof.
184        builder.cycle_tracker_v2_enter("verify-zerocheck");
185        self.verify_zerocheck(
186            builder,
187            &shard_chips,
188            opened_values,
189            &logup_gkr_proof.logup_evaluations,
190            zerocheck_proof,
191            public_values,
192            challenger,
193        );
194        builder.cycle_tracker_v2_exit();
195
196        // Verify the opening proof.
197        let (preprocessed_openings_for_proof, main_openings_for_proof): (Vec<_>, Vec<_>) = proof
198            .opened_values
199            .chips
200            .values()
201            .map(|opening| (opening.preprocessed.clone(), opening.main.clone()))
202            .unzip();
203
204        let preprocessed_openings = preprocessed_openings_for_proof
205            .iter()
206            .map(|x| x.local.iter().as_slice())
207            .collect::<Vec<_>>();
208
209        let main_openings = main_openings_for_proof
210            .iter()
211            .map(|x| x.local.iter().copied().collect::<MleEval<_>>())
212            .collect::<Evaluations<_>>();
213
214        let filtered_preprocessed_openings = preprocessed_openings
215            .clone()
216            .into_iter()
217            .filter(|x| !x.is_empty())
218            .map(|x| x.iter().copied().collect::<MleEval<_>>())
219            .collect::<Evaluations<_>>();
220
221        let preprocessed_column_count = filtered_preprocessed_openings
222            .iter()
223            .map(|table_openings| table_openings.len())
224            .collect::<Vec<_>>();
225
226        let added_columns: Vec<usize> =
227            proof.evaluation_proof.column_counts.iter().map(|cc| cc[cc.len() - 2] + 1).collect();
228
229        let unfiltered_preprocessed_column_count = preprocessed_openings
230            .iter()
231            .map(|table_openings| table_openings.len())
232            .chain(std::iter::once(added_columns[0] - 1))
233            .collect::<Vec<_>>();
234
235        let main_column_count =
236            main_openings.iter().map(|table_openings| table_openings.len()).collect::<Vec<_>>();
237
238        let unfiltered_main_column_count = main_openings
239            .iter()
240            .map(|table_openings| table_openings.len())
241            .chain(std::iter::once(added_columns[1] - 1))
242            .collect::<Vec<_>>();
243
244        let (commitments, column_counts, unfiltered_column_counts, openings) = (
245            vec![vk.preprocessed_commit, *main_commitment],
246            vec![preprocessed_column_count, main_column_count.clone()],
247            vec![unfiltered_preprocessed_column_count, unfiltered_main_column_count],
248            Rounds { rounds: vec![filtered_preprocessed_openings, main_openings] },
249        );
250
251        let machine_jagged_verifier =
252            RecursiveMachineJaggedPcsVerifier::new(&self.pcs_verifier, column_counts.clone());
253
254        let openings = openings
255            .into_iter()
256            .map(|round| {
257                round
258                    .into_iter()
259                    .flat_map(std::iter::IntoIterator::into_iter)
260                    .collect::<MleEval<_>>()
261            })
262            .collect::<Vec<_>>();
263
264        builder.cycle_tracker_v2_enter("jagged-verifier");
265        let prefix_sum_felts = machine_jagged_verifier.verify_trusted_evaluations(
266            builder,
267            &commitments,
268            zerocheck_proof.point_and_eval.0.clone(),
269            &openings,
270            evaluation_proof,
271            challenger,
272        );
273        builder.cycle_tracker_v2_exit();
274
275        let row_count_felt: Felt<_> = builder
276            .constant(SP1Field::from_canonical_u32(1 << self.pcs_verifier.max_log_row_count));
277
278        let params: Vec<Vec<Felt<SP1Field>>> = unfiltered_column_counts
279            .iter()
280            .map(|round| {
281                round
282                    .iter()
283                    .copied()
284                    .zip(height_felts_map.values().copied().chain(std::iter::once(row_count_felt)))
285                    .flat_map(|(column_count, height)| {
286                        std::iter::repeat_n(height, column_count).collect::<Vec<_>>()
287                    })
288                    .collect::<Vec<_>>()
289            })
290            .collect();
291
292        let preprocessed_count = params[0].len();
293        let params = params.into_iter().flatten().collect::<Vec<_>>();
294
295        builder.cycle_tracker_v2_enter("jagged - prefix-sum-checks");
296        let mut param_index = 0;
297        // The prefix_sum_felts coming from the C::prefix_sum_checks call excludes what is the last
298        // element, namely the total area, in the Rust verifier. We add that check in manually
299        // below. That is why the Rust verifier `skip_indices` has two elements, while this
300        // one has one.
301        let skip_indices = [preprocessed_count];
302
303        prefix_sum_felts
304            .iter()
305            .zip(prefix_sum_felts.iter().skip(1))
306            .enumerate()
307            .filter(|(i, _)| !skip_indices.contains(i))
308            .for_each(|(_, (x, y))| {
309                let sum = *x + params[param_index];
310                builder.assert_felt_eq(sum, *y);
311                param_index += 1;
312            });
313
314        builder.assert_felt_eq(prefix_sum_felts[0], SP1Field::zero());
315
316        // Check that the preprocessed prefix sum is the correct multiple of `stacking_height`.
317        builder.assert_felt_eq(
318            prefix_sum_felts[skip_indices[0] + 1],
319            SP1Field::from_canonical_usize(
320                (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
321                    * evaluation_proof.pcs_proof.batch_evaluations.rounds[0].num_polynomials(),
322            ),
323        );
324
325        let preprocessed_padding_col_height =
326            builder.eval(prefix_sum_felts[skip_indices[0] + 1] - prefix_sum_felts[skip_indices[0]]);
327        let preprocessed_padding_col_bit_decomp = C::num2bits(
328            builder,
329            preprocessed_padding_col_height,
330            self.pcs_verifier.max_log_row_count + 1,
331        );
332
333        // We want to constrain the padding column to be in the range [0, 2^{max_log_row_count}].
334        // The above constraints ensure that the padding column is in the range [0,
335        // 2^{max_log_row_count+1}). The following constraints exclude the range
336        // (2^{max_log_row_count}, 2^{max_log_row_count+1}), namely by ensuring that if the
337        // the `max_log_row_count`-th bit is 1, then the less significant bits must be zero.
338        //
339        // NOTE: Strictly speaking, this is not necessary, since the jagged polynomial will
340        // force a zero evaluation in case any column height is greater than
341        // `2^{max_log_row_count}`, but we add this constraint for extra security, since it
342        // does not have a significant performance impact.
343        let max_bit = preprocessed_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
344        let max_bit = C::bits2num(builder, vec![max_bit]);
345        let zero: Felt<_> = builder.constant(SP1Field::zero());
346        for bit in
347            preprocessed_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count)
348        {
349            let bit_felt = C::bits2num(builder, vec![*bit]);
350            builder.assert_felt_eq(max_bit * bit_felt, zero);
351        }
352        let num_cols = prefix_sum_felts.len();
353
354        // Repeat the process above for the main trace padding column.
355        let main_padding_col_height =
356            builder.eval(prefix_sum_felts[num_cols - 1] - prefix_sum_felts[num_cols - 2]);
357
358        let main_padding_col_bit_decomp = C::num2bits(builder, main_padding_col_height, NUM_BITS);
359
360        let max_bit = main_padding_col_bit_decomp[self.pcs_verifier.max_log_row_count];
361        let max_bit = C::bits2num(builder, vec![max_bit]);
362        for bit in main_padding_col_bit_decomp.iter().skip(self.pcs_verifier.max_log_row_count + 1)
363        {
364            C::assert_bit_zero(builder, *bit);
365        }
366        for bit in main_padding_col_bit_decomp.iter().take(self.pcs_verifier.max_log_row_count) {
367            let bit_felt = C::bits2num(builder, vec![*bit]);
368            builder.assert_felt_eq(max_bit * bit_felt, zero);
369        }
370
371        // Compute the total area from the shape of the stacked PCS proof.
372        let total_area_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(
373            (1 << self.pcs_verifier.stacked_pcs_verifier.log_stacking_height)
374                * proof
375                    .evaluation_proof
376                    .pcs_proof
377                    .batch_evaluations
378                    .iter()
379                    .map(|evaluations| evaluations.num_polynomials())
380                    .sum::<usize>(),
381        ));
382
383        // Convert the final prefix sum to a symbolic felt.
384        let mut acc = SymbolicFelt::zero();
385        // Assert max height to avoid overflow during prefix-sum-checks.
386        proof.evaluation_proof.params.col_prefix_sums.iter().last().unwrap().iter().for_each(|x| {
387            acc = *x + two * acc;
388        });
389
390        // Check equality between the two above-computed values.
391        builder.assert_felt_eq(acc, total_area_felt);
392
393        builder.cycle_tracker_v2_exit();
394    }
395}
396
397pub type RecursiveVerifierPublicValuesConstraintFolder<'a> =
398    GenericVerifierPublicValuesConstraintFolder<
399        'a,
400        SP1Field,
401        SP1ExtensionField,
402        Felt<SP1Field>,
403        Ext<SP1Field, SP1ExtensionField>,
404        SymbolicExt<SP1Field, SP1ExtensionField>,
405    >;
406
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use slop_basefold::{BasefoldVerifier, FriConfig};
    use sp1_core_executor::{Program, SP1Context, SP1CoreOpts};
    use sp1_core_machine::{
        io::SP1Stdin,
        riscv::RiscvAir,
        utils::{prove_core, setup_logger},
    };
    use sp1_hypercube::{
        prover::{CpuShardProver, SP1InnerPcsProver, SimpleProver},
        MachineVerifier, SP1InnerPcs, ShardVerifier, NUM_SP1_COMMITMENTS,
    };
    use sp1_recursion_compiler::{
        circuit::{AsmCompiler, AsmConfig},
        config::InnerConfig,
    };
    use sp1_recursion_machine::test::run_recursion_test_machines;

    use crate::{
        basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
        challenger::DuplexChallengerVariable,
        dummy::dummy_shard_proof,
        jagged::RecursiveJaggedEvalSumcheckConfig,
        witness::Witnessable,
    };

    use super::*;

    use sp1_primitives::{SP1Field, SP1GlobalContext};
    type GC = SP1GlobalContext;
    type C = InnerConfig;
    type A = RiscvAir<SP1Field>;

    /// Proves the Fibonacci test program natively, then builds the recursive shard-verifier
    /// circuit from a dummy proof of the same shape and runs it with the real proof as witness.
    #[tokio::test]
    async fn test_verify_shard() {
        setup_logger();
        let log_stacking_height = 21;
        let max_log_row_count = 22;
        let machine = RiscvAir::machine();
        let verifier = ShardVerifier::from_basefold_parameters(
            FriConfig::default_fri_config(),
            log_stacking_height,
            max_log_row_count,
            machine.clone(),
        );

        let elf = test_artifacts::FIBONACCI_ELF;
        let program = Arc::new(Program::from(&elf).unwrap());
        let shard_prover =
            CpuShardProver::<SP1GlobalContext, SP1InnerPcs, SP1InnerPcsProver, _>::new(
                verifier.clone(),
            );
        let prover = SimpleProver::new(verifier.clone(), shard_prover);

        let (pk, vk) = prover.setup(program.clone()).await;
        // SAFETY: NOTE(review): the contract of `into_inner` is not visible in this file; confirm
        // that unwrapping the proving key here upholds it.
        let pk = unsafe { pk.into_inner() };
        let (proof, _) = prove_core(
            &prover,
            pk,
            program,
            SP1Stdin::default(),
            SP1CoreOpts::default(),
            SP1Context::default(),
        )
        .await
        .unwrap();

        let mut builder = Builder::<C>::default();

        // Build the initial challenger and observe the verifying key into it.
        let mut initial_challenger = verifier.jagged_pcs_verifier.challenger();
        vk.observe_into(&mut initial_challenger);

        // Sanity check: the proof must verify natively before it is verified recursively.
        let machine_verifier = MachineVerifier::new(verifier);
        machine_verifier.verify(&vk, &proof).unwrap();

        let shard_proof = proof.shard_proofs[0].clone();
        let shape = machine_verifier.shape_from_proof(&shard_proof);

        // The circuit is built from a dummy proof with the same shape as the real one; the real
        // proof is only supplied later through the witness stream.
        let dummy_proof = dummy_shard_proof(
            shape.shard_chips,
            max_log_row_count,
            FriConfig::default_fri_config(),
            log_stacking_height as usize,
            &[
                shape.preprocessed_area >> log_stacking_height,
                shape.main_area >> log_stacking_height,
            ],
            &[shape.preprocessed_padding_cols, shape.main_padding_cols],
        );

        let vk_variable = vk.read(&mut builder);
        let shard_proof_variable = dummy_proof.read(&mut builder);

        // Assemble the recursive verifier stack: basefold -> stacked PCS -> jagged PCS -> shard.
        let basefold_verifier =
            BasefoldVerifier::<GC>::new(FriConfig::default_fri_config(), NUM_SP1_COMMITMENTS);
        let recursive_basefold_verifier = crate::basefold::RecursiveBasefoldVerifier::<C, GC> {
            fri_config: basefold_verifier.fri_config,
            tcs: RecursiveMerkleTreeTcs::<C, GC>(PhantomData),
        };
        let recursive_stacked_verifier =
            RecursiveStackedPcsVerifier::new(recursive_basefold_verifier, log_stacking_height);

        let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<GC, C> {
            stacked_pcs_verifier: recursive_stacked_verifier,
            max_log_row_count,
            jagged_evaluator: RecursiveJaggedEvalSumcheckConfig::<GC>(PhantomData),
        };

        let stark_verifier = RecursiveShardVerifier::<GC, A, C> {
            machine,
            pcs_verifier: recursive_jagged_verifier,
            _phantom: PhantomData,
        };

        let mut challenger_variable =
            DuplexChallengerVariable::from_challenger(&mut builder, &initial_challenger);

        builder.cycle_tracker_v2_enter("verify-shard");
        stark_verifier.verify_shard(
            &mut builder,
            &vk_variable,
            &shard_proof_variable,
            &mut challenger_variable,
        );
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = compiler.compile_inner(block).validate().unwrap();

        // Witness order must match the order in which the variables were read above.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&vk, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&shard_proof, &mut witness_stream);

        run_recursion_test_machines(program, witness_stream).await;
    }
}