Skip to main content

sp1_recursion_circuit/
zerocheck.rs

1use std::{collections::BTreeSet, ops::Deref};
2
3use crate::{
4    challenger::{CanObserveVariable, FieldChallengerVariable},
5    shard::RecursiveShardVerifier,
6    sumcheck::verify_sumcheck,
7    symbolic::IntoSymbolic,
8    CircuitConfig, SP1FieldConfigVariable,
9};
10use itertools::Itertools;
11use slop_air::{Air, BaseAir};
12use slop_algebra::AbstractField;
13use slop_challenger::IopCtx;
14use slop_matrix::dense::RowMajorMatrixView;
15use slop_multilinear::{full_geq, Mle, Point};
16use slop_sumcheck::PartialSumcheckProof;
17use sp1_hypercube::{
18    air::MachineAir, Chip, ChipOpenedValues, GenericVerifierConstraintFolder, LogUpEvaluations,
19    OpeningShapeError, ShardOpenedValues,
20};
21use sp1_primitives::{SP1ExtensionField, SP1Field};
22use sp1_recursion_compiler::{
23    ir::Felt,
24    prelude::{Builder, Ext, SymbolicExt},
25};
26
27pub type RecursiveVerifierConstraintFolder<'a> = GenericVerifierConstraintFolder<
28    'a,
29    SP1Field,
30    SP1ExtensionField,
31    Felt<SP1Field>,
32    Ext<SP1Field, SP1ExtensionField>,
33    SymbolicExt<SP1Field, SP1ExtensionField>,
34>;
35
36#[allow(clippy::type_complexity)]
37pub fn eval_constraints<C: CircuitConfig, SC: SP1FieldConfigVariable<C>, A>(
38    builder: &mut Builder<C>,
39    chip: &Chip<SP1Field, A>,
40    opening: &ChipOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
41    alpha: Ext<SP1Field, SP1ExtensionField>,
42    public_values: &[Felt<SP1Field>],
43) -> Ext<SP1Field, SP1ExtensionField>
44where
45    A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
46{
47    let mut folder = RecursiveVerifierConstraintFolder {
48        preprocessed: RowMajorMatrixView::new_row(&opening.preprocessed.local),
49        main: RowMajorMatrixView::new_row(&opening.main.local),
50        public_values,
51        alpha,
52        accumulator: SymbolicExt::zero(),
53        _marker: std::marker::PhantomData,
54    };
55
56    chip.eval(&mut folder);
57    builder.eval(folder.accumulator)
58}
59
60/// Compute the padded row adjustment for a chip.
61pub fn compute_padded_row_adjustment<C: CircuitConfig, A>(
62    builder: &mut Builder<C>,
63    chip: &Chip<SP1Field, A>,
64    alpha: Ext<SP1Field, SP1ExtensionField>,
65    public_values: &[Felt<SP1Field>],
66) -> Ext<SP1Field, SP1ExtensionField>
67where
68    A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
69{
70    let zero = builder.constant(SP1ExtensionField::zero());
71    let dummy_preprocessed_trace = vec![zero; chip.preprocessed_width()];
72    let dummy_main_trace = vec![zero; chip.width()];
73
74    let mut folder = RecursiveVerifierConstraintFolder {
75        preprocessed: RowMajorMatrixView::new_row(&dummy_preprocessed_trace),
76        main: RowMajorMatrixView::new_row(&dummy_main_trace),
77        alpha,
78        accumulator: SymbolicExt::zero(),
79        public_values,
80        _marker: std::marker::PhantomData,
81    };
82
83    chip.eval(&mut folder);
84    builder.eval(folder.accumulator)
85}
86
87#[allow(clippy::type_complexity)]
88pub fn verify_opening_shape<C: CircuitConfig, A>(
89    chip: &Chip<SP1Field, A>,
90    opening: &ChipOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
91) -> Result<(), OpeningShapeError>
92where
93    A: MachineAir<SP1Field> + for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
94{
95    // Verify that the preprocessed width matches the expected value for the chip.
96    if opening.preprocessed.local.len() != chip.preprocessed_width() {
97        return Err(OpeningShapeError::PreprocessedWidthMismatch(
98            chip.preprocessed_width(),
99            opening.preprocessed.local.len(),
100        ));
101    }
102
103    // Verify that the main width matches the expected value for the chip.
104    if opening.main.local.len() != chip.width() {
105        return Err(OpeningShapeError::MainWidthMismatch(chip.width(), opening.main.local.len()));
106    }
107
108    Ok(())
109}
110
111impl<GC, C, A> RecursiveShardVerifier<GC, A, C>
112where
113    GC: IopCtx<F = SP1Field, EF = SP1ExtensionField> + SP1FieldConfigVariable<C>,
114    C: CircuitConfig,
115    A: MachineAir<SP1Field>,
116{
117    #[allow(clippy::too_many_arguments)]
118    #[allow(clippy::type_complexity)]
119    pub fn verify_zerocheck(
120        &self,
121        builder: &mut Builder<C>,
122        shard_chips: &BTreeSet<Chip<SP1Field, A>>,
123        opened_values: &ShardOpenedValues<Felt<SP1Field>, Ext<SP1Field, SP1ExtensionField>>,
124        gkr_evaluations: &LogUpEvaluations<Ext<SP1Field, SP1ExtensionField>>,
125        zerocheck_proof: &PartialSumcheckProof<Ext<SP1Field, SP1ExtensionField>>,
126        public_values: &[Felt<SP1Field>],
127        challenger: &mut GC::FriChallengerVariable,
128    ) where
129        A: for<'a> Air<RecursiveVerifierConstraintFolder<'a>>,
130    {
131        let zero: Ext<SP1Field, SP1ExtensionField> = builder.constant(SP1ExtensionField::zero());
132        let one: Ext<SP1Field, SP1ExtensionField> = builder.constant(SP1ExtensionField::one());
133        let mut rlc_eval: Ext<SP1Field, SP1ExtensionField> = zero;
134
135        let alpha = challenger.sample_ext(builder);
136        let gkr_batch_open_challenge: SymbolicExt<SP1Field, SP1ExtensionField> =
137            challenger.sample_ext(builder).into();
138        let lambda = challenger.sample_ext(builder);
139
140        // Get the value of eq(zeta, sumcheck's reduced point).
141        let point_symbolic =
142            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(
143                &zerocheck_proof.point_and_eval.0,
144            );
145
146        let gkr_evaluations_point = IntoSymbolic::<C>::as_symbolic(&gkr_evaluations.point);
147
148        let zerocheck_eq_val = Mle::full_lagrange_eval(&gkr_evaluations_point, &point_symbolic);
149
150        let max_elements = shard_chips
151            .iter()
152            .map(|chip| chip.width() + chip.preprocessed_width())
153            .max()
154            .unwrap_or(0);
155
156        let gkr_batch_open_challenge_powers =
157            gkr_batch_open_challenge.powers().skip(1).take(max_elements).collect::<Vec<_>>();
158
159        for (chip, openings) in shard_chips.iter().zip_eq(opened_values.chips.values()) {
160            // Verify the shape of the opening arguments matches the expected values.
161            verify_opening_shape::<C, A>(chip, openings).unwrap();
162
163            let dimension = zerocheck_proof.point_and_eval.0.dimension();
164
165            assert_eq!(dimension, self.pcs_verifier.max_log_row_count);
166
167            let mut proof_point_extended = point_symbolic.clone();
168            proof_point_extended.add_dimension(zero.into());
169            let degree_symbolic_ext: Point<SymbolicExt<SP1Field, SP1ExtensionField>> =
170                openings.degree.iter().map(|x| SymbolicExt::from(*x)).collect::<Point<_>>();
171            degree_symbolic_ext.iter().enumerate().for_each(|(i, x)| {
172                builder.assert_ext_eq(*x * (*x - one), zero);
173                if i >= 1 {
174                    builder.assert_ext_eq(*x * *degree_symbolic_ext.first().unwrap(), zero);
175                }
176            });
177            let geq_val = full_geq(&degree_symbolic_ext, &proof_point_extended);
178
179            let padded_row_adjustment =
180                compute_padded_row_adjustment(builder, chip, alpha, public_values);
181
182            let constraint_eval =
183                eval_constraints::<C, GC, A>(builder, chip, openings, alpha, public_values)
184                    - padded_row_adjustment * geq_val;
185
186            let openings_batch = openings
187                .main
188                .local
189                .iter()
190                .chain(openings.preprocessed.local.iter())
191                .copied()
192                .zip(
193                    gkr_batch_open_challenge_powers
194                        .iter()
195                        .take(openings.main.local.len() + openings.preprocessed.local.len())
196                        .copied(),
197                )
198                .map(|(opening, power)| opening * power)
199                .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>();
200
201            rlc_eval = builder
202                .eval(rlc_eval * lambda + zerocheck_eq_val * (constraint_eval + openings_batch));
203        }
204
205        builder.assert_ext_eq(rlc_eval, zerocheck_proof.point_and_eval.1);
206
207        let zerocheck_sum_modifications_from_gkr = gkr_evaluations
208            .chip_openings
209            .values()
210            .map(|chip_evaluation| {
211                chip_evaluation
212                    .main_trace_evaluations
213                    .deref()
214                    .iter()
215                    .copied()
216                    .chain(
217                        chip_evaluation
218                            .preprocessed_trace_evaluations
219                            .as_ref()
220                            .iter()
221                            .flat_map(|&evals| evals.deref().iter().copied()),
222                    )
223                    .zip(gkr_batch_open_challenge_powers.iter().copied())
224                    .map(|(opening, power)| opening * power)
225                    .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>()
226            })
227            .collect::<Vec<_>>();
228
229        let zerocheck_sum_modification: SymbolicExt<SP1Field, SP1ExtensionField> =
230            zerocheck_sum_modifications_from_gkr
231                .iter()
232                .fold(zero.into(), |acc, modification| lambda * acc + *modification);
233
234        // Verify that the rlc claim is zero.
235        builder.assert_ext_eq(zerocheck_proof.claimed_sum, zerocheck_sum_modification);
236
237        // Verify the zerocheck proof.
238        verify_sumcheck::<C, GC>(builder, challenger, zerocheck_proof);
239
240        // Observe the openings
241        let len_felt: Felt<_> = builder.constant(SP1Field::from_canonical_usize(shard_chips.len()));
242        challenger.observe(builder, len_felt);
243        for opening in opened_values.chips.values() {
244            challenger
245                .observe_variable_length_extension_slice(builder, &opening.preprocessed.local);
246            challenger.observe_variable_length_extension_slice(builder, &opening.main.local);
247        }
248    }
249}
250
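// TODO: Add tests back. The commented-out test below predates the current
// `verify_zerocheck` signature: it passes the challenger second, along with separate GKR
// points and column openings. It also uses items not imported here (`StarkVerifier`,
// `RecursiveBasefoldVerifier`, `MleEval`), so it must be ported before being re-enabled.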
252// #[cfg(test)]
253// mod tests {
254//     use std::{marker::PhantomData, sync::Arc};
255
256//     use slop_algebra::extension::BinomialExtensionField;
257//     use sp1_primitives::SP1DiffusionMatrix;
258//     use slop_basefold::{BasefoldVerifier, SP1BasefoldConfig};
259//     use slop_jagged::SP1InnerPcs;
260//     use sp1_hypercube::inner_perm;
261//     use sp1_core_executor::{Program, SP1Context};
262//     use sp1_core_machine::{io::SP1Stdin, riscv::RiscvAir, utils::prove_core};
263//     use sp1_recursion_compiler::{
264//         circuit::{AsmCompiler, AsmConfig},
265//         config::InnerConfig,
266//     };
267//     use sp1_recursion_executor::Runtime;
268//     use sp1_hypercube::{prover::CpuProver, SP1CoreOpts, ShardVerifier};
269
270//     use crate::{
271//         basefold::{stacked::RecursiveStackedPcsVerifier, tcs::RecursiveMerkleTreeTcs},
272//         challenger::DuplexChallengerVariable,
273//         jagged::{
274//             RecursiveJaggedConfigImpl, RecursiveJaggedEvalSumcheckConfig,
275//             RecursiveJaggedPcsVerifier,
276//         },
277//         witness::Witnessable,
278//     };
279
280//     use super::*;
281
282//     use sp1_primitives::SP1Field;
283//    type F = SP1Field;
284//     type SC = SP1InnerPcs;
285//     type JC = RecursiveJaggedConfigImpl<
286//         C,
287//         SC,
288//         RecursiveBasefoldVerifier<RecursiveBasefoldConfigImpl<C, SC>>,
289//     >;
290//     type C = InnerConfig;
291//     type EF = BinomialExtensionField<SP1Field, 4>;
292//     type A = RiscvAir<SP1Field>;
293
294//     #[tokio::test]
295//     async fn test_zerocheck() {
296//         let program = Program::from(test_artifacts::FIBONACCI_ELF).unwrap();
297//         let log_blowup = 1;
298//         let log_stacking_height = 21;
299//         let max_log_row_count = 21;
300//         let machine = RiscvAir::machine();
301//         let verifier = ShardVerifier::from_basefold_parameters(
302//             log_blowup,
303//             log_stacking_height,
304//             max_log_row_count,
305//             machine.clone(),
306//         );
307//         let prover = CpuProver::new(verifier.clone());
308
309//         let (pk, _) = prover.setup(Arc::new(program.clone())).await;
310
311//         let challenger = verifier.pcs_verifier.challenger();
312
313//         let (proof, _) = prove_core(
314//             Arc::new(prover),
315//             Arc::new(pk),
316//             Arc::new(program.clone()),
317//             &SP1Stdin::new(),
318//             SP1CoreOpts::default(),
319//             SP1Context::default(),
320//             challenger,
321//         )
322//         .await
323//         .unwrap();
324
325//         let shard_proof = proof.shard_proofs[0].clone();
326//         let challenger_state = shard_proof.testing_data.challenger_state.clone();
327
328//         let mut builder = Builder::<C>::default();
329
330//         let mut challenger_variable =
331//             DuplexChallengerVariable::from_challenger(&mut builder, &challenger_state);
332
333//         let shard_proof_variable = shard_proof.read(&mut builder);
334
335//         let gkr_points_variable = shard_proof.testing_data.gkr_points.read(&mut builder);
336//         let gkr_column_openings_variable = shard_proof
337//             .gkr_proofs
338//             .iter()
339//             .map(|gkr_proof| {
340//                 let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
341//                 let main_openings_variable = main_openings.read(&mut builder);
342//                 let preprocessed_openings_variable: MleEval<Ext<_, _>> = preprocessed_openings
343//                     .as_ref()
344//                     .map(MleEval::to_vec)
345//                     .unwrap_or_default()
346//                     .read(&mut builder)
347//                     .into();
348//                 (main_openings_variable, preprocessed_openings_variable)
349//             })
350//             .collect::<Vec<_>>();
351
352//         let verifier = BasefoldVerifier::<SP1BasefoldConfig>::new(log_blowup);
353//         let recursive_verifier = RecursiveBasefoldVerifier::<RecursiveBasefoldConfigImpl<C, SC>>
354// {             fri_config: verifier.fri_config,
355//             tcs: RecursiveMerkleTreeTcs::<C, SC>(PhantomData),
356//         };
357//         let recursive_verifier =
358//             RecursiveStackedPcsVerifier::new(recursive_verifier, log_stacking_height);
359
360//         let recursive_jagged_verifier = RecursiveJaggedPcsVerifier::<
361//             SC,
362//             C,
363//             RecursiveJaggedConfigImpl<
364//                 C,
365//                 SC,
366//                 RecursiveBasefoldVerifier<RecursiveBasefoldConfigImpl<C, SC>>,
367//             >,
368//         > { stacked_pcs_verifier: recursive_verifier, max_log_row_count, jagged_evaluator:
369//         > RecursiveJaggedEvalSumcheckConfig::<SP1InnerPcs>(PhantomData),
370//         };
371
372//         let stark_verifier = StarkVerifier::<A, SC, C, JC> {
373//             machine,
374//             pcs_verifier: recursive_jagged_verifier,
375//             _phantom: std::marker::PhantomData,
376//         };
377
378//         stark_verifier.verify_zerocheck(
379//             &mut builder,
380//             &mut challenger_variable,
381//             &shard_proof_variable.opened_values,
382//             &shard_proof_variable.zerocheck_proof,
383//             &gkr_points_variable,
384//             &gkr_column_openings_variable,
385//             &shard_proof_variable.public_values,
386//         );
387
388//         let mut witness_stream = Vec::new();
389//         Witnessable::<AsmConfig<F, EF>>::write(&shard_proof, &mut witness_stream);
390//         Witnessable::<AsmConfig<F, EF>>::write(
391//             &shard_proof.testing_data.gkr_points,
392//             &mut witness_stream,
393//         );
394//         shard_proof.gkr_proofs.iter().for_each(|gkr_proof| {
395//             let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
396//             Witnessable::<AsmConfig<F, EF>>::write(main_openings, &mut witness_stream);
397//             let preprocessed_openings_unwrapped: MleEval<_> =
398//                 preprocessed_openings.as_ref().map(MleEval::to_vec).unwrap_or_default().into();
399//             Witnessable::<AsmConfig<F, EF>>::write(
400//                 &preprocessed_openings_unwrapped,
401//                 &mut witness_stream,
402//             );
403//         });
404
405//         let block = builder.into_root_block();
406//         let mut compiler = AsmCompiler::<AsmConfig<F, EF>>::default();
407//         let program = Arc::new(compiler.compile_inner(block).validate().unwrap());
408//         let mut executor =
409//             Runtime::<F, EF, SP1DiffusionMatrix>::new(program.clone(), inner_perm());
410//         executor.witness_stream = witness_stream.into();
411//         executor.run().unwrap();
412
413//         // Test for a bad zerocheck proof.
414//         let mut invalid_shard_proof = shard_proof.clone();
415//         invalid_shard_proof.zerocheck_proof.univariate_polys[0].coefficients[0] += EF::one();
416//         let mut witness_stream = Vec::new();
417//         Witnessable::<AsmConfig<F, EF>>::write(&invalid_shard_proof, &mut witness_stream);
418//         Witnessable::<AsmConfig<F, EF>>::write(
419//             &invalid_shard_proof.testing_data.gkr_points,
420//             &mut witness_stream,
421//         );
422//         invalid_shard_proof.gkr_proofs.iter().for_each(|gkr_proof| {
423//             let (main_openings, preprocessed_openings) = &gkr_proof.column_openings;
424//             Witnessable::<AsmConfig<F, EF>>::write(main_openings, &mut witness_stream);
425//             let preprocessed_openings_unwrapped: MleEval<_> =
426//                 preprocessed_openings.as_ref().map(MleEval::to_vec).unwrap_or_default().into();
427//             Witnessable::<AsmConfig<F, EF>>::write(
428//                 &preprocessed_openings_unwrapped,
429//                 &mut witness_stream,
430//             );
431//         });
432//         let mut executor = Runtime::<F, EF, SP1DiffusionMatrix>::new(program,
433// inner_perm());         executor.witness_stream = witness_stream.into();
434//         executor.run().expect_err("invalid proof should not be verified");
435//     }
436// }