Skip to main content

sp1_recursion_circuit/jagged/
jagged_eval.rs

1use std::marker::PhantomData;
2
3use rayon::ThreadPoolBuilder;
4use slop_jagged::{
5    BranchingProgram, JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof,
6};
7use slop_multilinear::{Mle, Point};
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10    circuit::CircuitV2Builder,
11    ir::{Builder, Ext, Felt, SymbolicExt, SymbolicFelt},
12};
13
14use crate::{
15    challenger::FieldChallengerVariable, sumcheck::verify_sumcheck, symbolic::IntoSymbolic,
16    CircuitConfig, SP1FieldConfigVariable,
17};
18
19impl<C: CircuitConfig> IntoSymbolic<C> for JaggedLittlePolynomialVerifierParams<Felt<SP1Field>> {
20    type Output = JaggedLittlePolynomialVerifierParams<SymbolicFelt<SP1Field>>;
21
22    fn as_symbolic(&self) -> Self::Output {
23        JaggedLittlePolynomialVerifierParams {
24            col_prefix_sums: self
25                .col_prefix_sums
26                .iter()
27                .map(|x| <Point<Felt<SP1Field>> as IntoSymbolic<C>>::as_symbolic(x))
28                .collect::<Vec<_>>(),
29        }
30    }
31}
32
33pub trait RecursiveJaggedEvalConfig<C: CircuitConfig, Chal>: Sized {
34    type JaggedEvalProof;
35
36    #[allow(clippy::too_many_arguments)]
37    #[allow(clippy::type_complexity)]
38    fn jagged_evaluation(
39        &self,
40        builder: &mut Builder<C>,
41        params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
42        z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
43        z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
44        z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
45        proof: &Self::JaggedEvalProof,
46        challenger: &mut Chal,
47    ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>);
48}
49
/// Jagged evaluation strategy that computes the jagged little polynomial directly on symbolic
/// values; it takes no proof (`()`) and no challenger (`()`).
#[derive(Debug, Clone, Copy)]
pub struct RecursiveTrivialJaggedEvalConfig;
51
52impl<C: CircuitConfig> RecursiveJaggedEvalConfig<C, ()> for RecursiveTrivialJaggedEvalConfig {
53    type JaggedEvalProof = ();
54
55    fn jagged_evaluation(
56        &self,
57        _builder: &mut Builder<C>,
58        params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
59        z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
60        z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
61        z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
62        _proof: &Self::JaggedEvalProof,
63        _challenger: &mut (),
64    ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>) {
65        let params_ef = JaggedLittlePolynomialVerifierParams {
66            col_prefix_sums: params
67                .col_prefix_sums
68                .iter()
69                .map(|x| x.iter().map(|y| SymbolicExt::from(*y)).collect())
70                .collect::<Vec<_>>(),
71        };
72        let z_row =
73            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_row);
74        let z_col =
75            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_col);
76        let z_trace =
77            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_trace);
78
79        // Need to use a single threaded rayon pool.
80        let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
81        let result = pool.install(|| {
82            params_ef.full_jagged_little_polynomial_evaluation(&z_row, &z_col, &z_trace)
83        });
84        (result, vec![])
85    }
86}
87
/// Jagged evaluation strategy that verifies a jagged-evaluation sumcheck proof in-circuit.
///
/// The struct carries no data; `SC` only selects the field configuration (and with it the
/// challenger type used during verification).
#[derive(Clone, Debug)]
pub struct RecursiveJaggedEvalSumcheckConfig<SC>(pub PhantomData<SC>);
90
91impl<C: CircuitConfig, SC: SP1FieldConfigVariable<C>>
92    RecursiveJaggedEvalConfig<C, SC::FriChallengerVariable>
93    for RecursiveJaggedEvalSumcheckConfig<SC>
94{
95    type JaggedEvalProof = JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>;
96
97    fn jagged_evaluation(
98        &self,
99        builder: &mut Builder<C>,
100        params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
101        z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
102        z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
103        z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
104        proof: &Self::JaggedEvalProof,
105        challenger: &mut SC::FriChallengerVariable,
106    ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>) {
107        let z_row =
108            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_row);
109        let z_col =
110            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_col);
111        let z_trace =
112            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_trace);
113
114        let JaggedSumcheckEvalProof { partial_sumcheck_proof } = proof;
115        // Calculate the partial lagrange from z_col point.
116        let z_col_partial_lagrange = Mle::blocking_partial_lagrange(&z_col);
117        let z_col_partial_lagrange = z_col_partial_lagrange.guts().as_slice();
118
119        // Calculate the jagged eval from the branching program eval claims.
120        let jagged_eval = partial_sumcheck_proof.claimed_sum;
121
122        challenger.observe_ext_element(builder, jagged_eval);
123
124        builder.assert_ext_eq(jagged_eval, partial_sumcheck_proof.claimed_sum);
125
126        // Verify the jagged eval proof.
127        builder.cycle_tracker_v2_enter("jagged eval - verify sumcheck");
128        verify_sumcheck::<C, SC>(builder, challenger, partial_sumcheck_proof);
129        builder.cycle_tracker_v2_exit();
130        let proof_point = <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(
131            &partial_sumcheck_proof.point_and_eval.0,
132        );
133        let (first_half_z_index, second_half_z_index) =
134            proof_point.split_at(proof_point.dimension() / 2);
135        assert!(first_half_z_index.len() == second_half_z_index.len());
136
137        // Compute the jagged eval sc expected eval and assert it matches the proof's eval.
138        let current_column_prefix_sums = params.col_prefix_sums.iter();
139        let next_column_prefix_sums = params.col_prefix_sums.iter().skip(1);
140        let mut prefix_sum_felts = Vec::new();
141        builder.cycle_tracker_v2_enter("jagged eval - calculate expected eval");
142        let mut jagged_eval_sc_expected_eval = current_column_prefix_sums
143            .zip(next_column_prefix_sums)
144            .zip(z_col_partial_lagrange.iter())
145            .map(|((current_column_prefix_sum, next_column_prefix_sum), z_col_eq_val)| {
146                assert!(current_column_prefix_sum.dimension() <= 30);
147                assert!(next_column_prefix_sum.dimension() <= 30);
148
149                let mut merged_prefix_sum = current_column_prefix_sum.clone();
150                merged_prefix_sum.extend(next_column_prefix_sum);
151
152                let (full_lagrange_eval, felt) = C::prefix_sum_checks(
153                    builder,
154                    merged_prefix_sum.to_vec(),
155                    partial_sumcheck_proof.point_and_eval.0.to_vec(),
156                );
157                prefix_sum_felts.push(felt);
158                *z_col_eq_val * full_lagrange_eval
159            })
160            .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>();
161        builder.cycle_tracker_v2_exit();
162        let branching_program = BranchingProgram::new(z_row.clone(), z_trace.clone());
163        jagged_eval_sc_expected_eval *=
164            branching_program.eval(&first_half_z_index, &second_half_z_index);
165
166        builder
167            .assert_ext_eq(jagged_eval_sc_expected_eval, partial_sumcheck_proof.point_and_eval.1);
168
169        (jagged_eval.into(), prefix_sum_felts)
170    }
171}
172
#[cfg(test)]
mod tests {
    use std::{marker::PhantomData, sync::Arc};

    use rand::{thread_rng, Rng};
    use slop_algebra::{extension::BinomialExtensionField, AbstractField};
    use slop_alloc::CpuBackend;
    use slop_challenger::{DuplexChallenger, IopCtx};
    use slop_jagged::{
        JaggedAssistSumAsPolyCPUImpl, JaggedEvalProver, JaggedEvalSumcheckProver,
        JaggedLittlePolynomialProverParams, JaggedLittlePolynomialVerifierParams,
    };
    use slop_multilinear::Point;
    use sp1_core_machine::utils::setup_logger;
    use sp1_hypercube::{inner_perm, log2_ceil_usize};
    use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};
    use sp1_recursion_compiler::{
        circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder},
        ir::{Ext, Felt},
    };
    use sp1_recursion_executor::Executor;

    use crate::{
        challenger::DuplexChallengerVariable,
        jagged::jagged_eval::{
            RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig,
            RecursiveTrivialJaggedEvalConfig,
        },
        witness::Witnessable,
        SP1FieldConfigVariable,
    };

    use sp1_primitives::{SP1Field, SP1Perm};
    type F = SP1Field;
    type EF = BinomialExtensionField<SP1Field, 4>;
    type C = AsmConfig;
    type SC = SP1GlobalContext;

    /// Builds a recursion program that evaluates the jagged little polynomial at
    /// `(z_row, z_col, z_trace)` with `RecursiveTrivialJaggedEvalConfig`, asserts the result
    /// equals `expected_result`, and executes it.
    ///
    /// When `should_succeed` is true execution must succeed; otherwise it must return an error.
    fn trivial_jagged_eval(
        verifier_params: &JaggedLittlePolynomialVerifierParams<F>,
        z_row: &Point<EF>,
        z_col: &Point<EF>,
        z_trace: &Point<EF>,
        expected_result: EF,
        should_succeed: bool,
    ) {
        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("trivial-jagged-eval");
        // Allocate circuit variables for the inputs; the matching values are written to the
        // witness stream below in the same order.
        let verifier_params_variable = verifier_params.read(&mut builder);
        let z_row_variable = z_row.read(&mut builder);
        let z_col_variable = z_col.read(&mut builder);
        let z_trace_variable = z_trace.read(&mut builder);
        let recursive_jagged_evaluator = RecursiveTrivialJaggedEvalConfig {};
        // The trivial evaluator takes no proof and no challenger, hence the `()` arguments.
        let (recursive_jagged_evaluation, _) = <RecursiveTrivialJaggedEvalConfig as RecursiveJaggedEvalConfig<C, ()>>::jagged_evaluation(
            &recursive_jagged_evaluator,
            &mut builder,
            &verifier_params_variable,
            z_row_variable,
            z_col_variable,
            z_trace_variable,
            &(),
            &mut (),
        );
        // Materialize the symbolic result and constrain it to the host-computed value.
        let recursive_jagged_evaluation: Ext<F, EF> = builder.eval(recursive_jagged_evaluation);
        let expected_result: Ext<F, EF> = builder.constant(expected_result);
        builder.assert_ext_eq(recursive_jagged_evaluation, expected_result);
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = compiler.compile_inner(block).validate().unwrap();

        // Witness values, in the same order the variables were read above.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&verifier_params, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_row, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_col, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_trace, &mut witness_stream);

        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
        executor.witness_stream = witness_stream.into();
        if should_succeed {
            executor.run().unwrap();
        } else {
            executor.run().expect_err("invalid proof should not be verified");
        }
    }

    /// Proves a jagged evaluation on the host with `JaggedEvalSumcheckProver`. It then builds
    /// a recursion program that verifies the proof with `RecursiveJaggedEvalSumcheckConfig`,
    /// asserts the returned evaluation equals `expected_result`, and executes the program.
    ///
    /// When `should_succeed` is true execution must succeed; otherwise it must return an error.
    /// Returns the felt vector produced by `jagged_evaluation` while building the program.
    fn sumcheck_jagged_eval(
        prover_params: &JaggedLittlePolynomialProverParams,
        verifier_params: &JaggedLittlePolynomialVerifierParams<F>,
        z_row: &Point<EF>,
        z_col: &Point<EF>,
        z_trace: &Point<EF>,
        expected_result: EF,
        should_succeed: bool,
    ) -> Vec<Felt<F>> {
        let prover = JaggedEvalSumcheckProver::<
            F,
            JaggedAssistSumAsPolyCPUImpl<_, _, _>,
            CpuBackend,
            <SP1GlobalContext as IopCtx>::Challenger,
        >::default();
        // Host-side challenger for the prover; the circuit builds its own
        // `DuplexChallengerVariable` below, which is also newly constructed.
        let default_perm = inner_perm();
        let mut challenger =
            DuplexChallenger::<SP1Field, SP1Perm, 16, 8>::new(default_perm.clone());
        let jagged_eval_proof = prover.prove_jagged_evaluation(
            prover_params,
            z_row,
            z_col,
            z_trace,
            &mut challenger,
            CpuBackend,
        );

        let mut builder = AsmBuilder::default();
        builder.cycle_tracker_v2_enter("sumcheck-jagged-eval");
        // Allocate circuit variables for the inputs and the proof; the matching values are
        // written to the witness stream below in the same order.
        let verifier_params_variable = verifier_params.read(&mut builder);
        let z_row_variable = z_row.read(&mut builder);
        let z_col_variable = z_col.read(&mut builder);
        let z_trace_variable = z_trace.read(&mut builder);
        let jagged_eval_proof_variable = jagged_eval_proof.read(&mut builder);
        let recursive_jagged_evaluator = RecursiveJaggedEvalSumcheckConfig::<SC>(PhantomData);
        let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
        let (recursive_jagged_evaluation, prefix_sum_felts) =
            <RecursiveJaggedEvalSumcheckConfig<SC> as RecursiveJaggedEvalConfig<
                C,
                <SC as SP1FieldConfigVariable<C>>::FriChallengerVariable,
            >>::jagged_evaluation(
                &recursive_jagged_evaluator,
                &mut builder,
                &verifier_params_variable,
                z_row_variable,
                z_col_variable,
                z_trace_variable,
                &jagged_eval_proof_variable,
                &mut challenger_variable,
            );
        // Materialize the symbolic result and constrain it to the host-computed value.
        let recursive_jagged_evaluation: Ext<F, EF> = builder.eval(recursive_jagged_evaluation);
        let expected_result: Ext<F, EF> = builder.constant(expected_result);
        builder.assert_ext_eq(recursive_jagged_evaluation, expected_result);
        builder.cycle_tracker_v2_exit();

        let block = builder.into_root_block();
        let mut compiler = AsmCompiler::default();
        let program = compiler.compile_inner(block).validate().unwrap();

        // Witness values, in the same order the variables were read above.
        let mut witness_stream = Vec::new();
        Witnessable::<AsmConfig>::write(&verifier_params, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_row, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_col, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&z_trace, &mut witness_stream);
        Witnessable::<AsmConfig>::write(&jagged_eval_proof, &mut witness_stream);
        let mut executor =
            Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
        executor.witness_stream = witness_stream.into();
        if should_succeed {
            executor.run().unwrap();
        } else {
            executor.run().expect_err("invalid proof should not be verified");
        }
        prefix_sum_felts
    }

    /// Checks both in-circuit evaluators against the host-side
    /// `full_jagged_little_polynomial_evaluation`. They are run first at a random valid point,
    /// then with `z_row[0]` shifted by one, where both programs must fail to execute.
    #[test]
    fn test_jagged_eval_proof() {
        setup_logger();
        // Includes single-row and empty (0) columns.
        let row_counts = [12, 1, 2, 1, 17, 0];

        // Exclusive prefix sums of `row_counts`, followed by the total row count.
        let mut prefix_sums = row_counts
            .iter()
            .scan(0, |state, row_count| {
                let result = *state;
                *state += row_count;
                Some(result)
            })
            .collect::<Vec<_>>();
        prefix_sums.push(*prefix_sums.last().unwrap() + row_counts.last().unwrap());

        let mut rng = thread_rng();

        // Number of variables needed to index the total row count.
        let log_m = log2_ceil_usize(*prefix_sums.last().unwrap());

        let log_max_row_count = 7;

        let prover_params =
            JaggedLittlePolynomialProverParams::new(row_counts.to_vec(), log_max_row_count);

        let verifier_params: JaggedLittlePolynomialVerifierParams<F> =
            prover_params.clone().into_verifier_params();

        // Random evaluation point; dimensions follow the parameters above.
        let z_row: Point<EF> = (0..log_max_row_count).map(|_| rng.gen::<EF>()).collect();
        let z_col: Point<EF> =
            (0..log2_ceil_usize(row_counts.len())).map(|_| rng.gen::<EF>()).collect();
        let z_trace: Point<EF> = (0..log_m + 1).map(|_| rng.gen::<EF>()).collect();

        // Reference value computed on the host.
        let expected_result =
            verifier_params.full_jagged_little_polynomial_evaluation(&z_row, &z_col, &z_trace);

        trivial_jagged_eval(&verifier_params, &z_row, &z_col, &z_trace, expected_result, true);
        sumcheck_jagged_eval(
            &prover_params,
            &verifier_params,
            &z_row,
            &z_col,
            &z_trace,
            expected_result,
            true,
        );

        // Test the invalid cases: the circuits evaluate at the perturbed point but are still
        // compared against the original `expected_result`, so execution must fail.
        let mut z_row_invalid = z_row.clone();
        let first_element = z_row_invalid.get_mut(0).unwrap();
        *first_element += EF::one();
        trivial_jagged_eval(
            &verifier_params,
            &z_row_invalid,
            &z_col,
            &z_trace,
            expected_result,
            false,
        );
        sumcheck_jagged_eval(
            &prover_params,
            &verifier_params,
            &z_row_invalid,
            &z_col,
            &z_trace,
            expected_result,
            false,
        );
    }
}