1use std::marker::PhantomData;
2
3use rayon::ThreadPoolBuilder;
4use slop_jagged::{
5 BranchingProgram, JaggedLittlePolynomialVerifierParams, JaggedSumcheckEvalProof,
6};
7use slop_multilinear::{Mle, Point};
8use sp1_primitives::{SP1ExtensionField, SP1Field};
9use sp1_recursion_compiler::{
10 circuit::CircuitV2Builder,
11 ir::{Builder, Ext, Felt, SymbolicExt, SymbolicFelt},
12};
13
14use crate::{
15 challenger::FieldChallengerVariable, sumcheck::verify_sumcheck, symbolic::IntoSymbolic,
16 CircuitConfig, SP1FieldConfigVariable,
17};
18
19impl<C: CircuitConfig> IntoSymbolic<C> for JaggedLittlePolynomialVerifierParams<Felt<SP1Field>> {
20 type Output = JaggedLittlePolynomialVerifierParams<SymbolicFelt<SP1Field>>;
21
22 fn as_symbolic(&self) -> Self::Output {
23 JaggedLittlePolynomialVerifierParams {
24 col_prefix_sums: self
25 .col_prefix_sums
26 .iter()
27 .map(|x| <Point<Felt<SP1Field>> as IntoSymbolic<C>>::as_symbolic(x))
28 .collect::<Vec<_>>(),
29 }
30 }
31}
32
/// Strategy for obtaining the evaluation of the jagged little polynomial inside a recursion
/// circuit.
///
/// `C` is the circuit configuration the [`Builder`] is parameterised by, and `Chal` is the
/// challenger type handed to [`jagged_evaluation`](Self::jagged_evaluation).
pub trait RecursiveJaggedEvalConfig<C: CircuitConfig, Chal>: Sized {
    /// Proof data consumed by [`jagged_evaluation`](Self::jagged_evaluation).
    type JaggedEvalProof;

    /// Produces the jagged evaluation for `params` at the points `z_row`, `z_col` and `z_trace`.
    ///
    /// Implementations may record constraints on `builder`, read `proof`, and update
    /// `challenger`.
    ///
    /// Returns the evaluation as a symbolic extension-field expression together with a vector
    /// of felts. Of the two implementations in this file, one always returns an empty vector
    /// and the other returns one felt per `C::prefix_sum_checks` call.
    /// NOTE(review): what callers are expected to do with the returned felts is not visible
    /// here — confirm against the call sites.
    #[allow(clippy::too_many_arguments)]
    #[allow(clippy::type_complexity)]
    fn jagged_evaluation(
        &self,
        builder: &mut Builder<C>,
        params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
        z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
        z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
        z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
        proof: &Self::JaggedEvalProof,
        challenger: &mut Chal,
    ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>);
}
49
/// Jagged evaluator that computes the evaluation directly from the verifier parameters.
///
/// Its [`RecursiveJaggedEvalConfig`] implementation uses `()` for both the proof and the
/// challenger and returns no auxiliary felts.
pub struct RecursiveTrivialJaggedEvalConfig;
51
52impl<C: CircuitConfig> RecursiveJaggedEvalConfig<C, ()> for RecursiveTrivialJaggedEvalConfig {
53 type JaggedEvalProof = ();
54
55 fn jagged_evaluation(
56 &self,
57 _builder: &mut Builder<C>,
58 params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
59 z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
60 z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
61 z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
62 _proof: &Self::JaggedEvalProof,
63 _challenger: &mut (),
64 ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>) {
65 let params_ef = JaggedLittlePolynomialVerifierParams {
66 col_prefix_sums: params
67 .col_prefix_sums
68 .iter()
69 .map(|x| x.iter().map(|y| SymbolicExt::from(*y)).collect())
70 .collect::<Vec<_>>(),
71 };
72 let z_row =
73 <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_row);
74 let z_col =
75 <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_col);
76 let z_trace =
77 <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_trace);
78
79 let pool = ThreadPoolBuilder::new().num_threads(1).build().unwrap();
81 let result = pool.install(|| {
82 params_ef.full_jagged_little_polynomial_evaluation(&z_row, &z_col, &z_trace)
83 });
84 (result, vec![])
85 }
86}
87
/// Jagged evaluator backed by a sumcheck proof.
///
/// `SC` is carried only as a type-level marker through [`PhantomData`]; the struct holds no
/// runtime data.
#[derive(Debug, Clone)]
pub struct RecursiveJaggedEvalSumcheckConfig<SC>(pub PhantomData<SC>);
90
/// Sumcheck-based jagged evaluation: the evaluation is taken from the proof's claimed sum and
/// the proof is checked in-circuit.
impl<C: CircuitConfig, SC: SP1FieldConfigVariable<C>>
    RecursiveJaggedEvalConfig<C, SC::FriChallengerVariable>
    for RecursiveJaggedEvalSumcheckConfig<SC>
{
    type JaggedEvalProof = JaggedSumcheckEvalProof<Ext<SP1Field, SP1ExtensionField>>;

    /// Records the constraints that check `proof` and returns its claimed sum as the jagged
    /// evaluation.
    ///
    /// Steps, in the order they are recorded on `builder`:
    /// 1. Observe the proof's claimed sum on `challenger`.
    /// 2. Verify the partial sumcheck proof with `verify_sumcheck`.
    /// 3. Recompute the expected evaluation at the sumcheck point from `params`, `z_row`,
    ///    `z_col` and `z_trace`, and constrain it to equal the evaluation carried by the proof.
    ///
    /// Returns the claimed sum (converted with `into()`) and the felts returned by
    /// `C::prefix_sum_checks`, one per processed column, in column order.
    ///
    /// # Panics
    ///
    /// These are host-side `assert!`s that fire while the circuit is being built:
    /// - if splitting the sumcheck point at half its dimension yields parts of different length;
    /// - if any prefix sum used in the loop has dimension greater than 30.
    fn jagged_evaluation(
        &self,
        builder: &mut Builder<C>,
        params: &JaggedLittlePolynomialVerifierParams<Felt<SP1Field>>,
        z_row: Point<Ext<SP1Field, SP1ExtensionField>>,
        z_col: Point<Ext<SP1Field, SP1ExtensionField>>,
        z_trace: Point<Ext<SP1Field, SP1ExtensionField>>,
        proof: &Self::JaggedEvalProof,
        challenger: &mut SC::FriChallengerVariable,
    ) -> (SymbolicExt<SP1Field, SP1ExtensionField>, Vec<Felt<SP1Field>>) {
        // Work with symbolic versions of the evaluation points from here on.
        let z_row =
            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_row);
        let z_col =
            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_col);
        let z_trace =
            <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(&z_trace);

        let JaggedSumcheckEvalProof { partial_sumcheck_proof } = proof;
        // Partial Lagrange table of `z_col`, viewed as a flat slice; its entries weight the
        // per-column terms in the sum below.
        let z_col_partial_lagrange = Mle::blocking_partial_lagrange(&z_col);
        let z_col_partial_lagrange = z_col_partial_lagrange.guts().as_slice();

        // The evaluation returned to the caller is the sum claimed by the sumcheck proof.
        let jagged_eval = partial_sumcheck_proof.claimed_sum;

        challenger.observe_ext_element(builder, jagged_eval);

        // NOTE(review): `jagged_eval` was assigned from `claimed_sum` just above, so both
        // operands of this constraint name the same value.
        builder.assert_ext_eq(jagged_eval, partial_sumcheck_proof.claimed_sum);

        builder.cycle_tracker_v2_enter("jagged eval - verify sumcheck");
        verify_sumcheck::<C, SC>(builder, challenger, partial_sumcheck_proof);
        builder.cycle_tracker_v2_exit();
        // The sumcheck point is split into two halves, which are later passed to the branching
        // program as its two index arguments.
        let proof_point = <Point<Ext<SP1Field, SP1ExtensionField>> as IntoSymbolic<C>>::as_symbolic(
            &partial_sumcheck_proof.point_and_eval.0,
        );
        let (first_half_z_index, second_half_z_index) =
            proof_point.split_at(proof_point.dimension() / 2);
        assert!(first_half_z_index.len() == second_half_z_index.len());

        // Pair every column prefix sum with its successor; `zip` stops at the shortest input,
        // so the number of terms is also bounded by the length of the Lagrange table.
        let current_column_prefix_sums = params.col_prefix_sums.iter();
        let next_column_prefix_sums = params.col_prefix_sums.iter().skip(1);
        let mut prefix_sum_felts = Vec::new();
        builder.cycle_tracker_v2_enter("jagged eval - calculate expected eval");
        let mut jagged_eval_sc_expected_eval = current_column_prefix_sums
            .zip(next_column_prefix_sums)
            .zip(z_col_partial_lagrange.iter())
            .map(|((current_column_prefix_sum, next_column_prefix_sum), z_col_eq_val)| {
                // NOTE(review): the rationale for the bound of 30 is not visible in this file.
                assert!(current_column_prefix_sum.dimension() <= 30);
                assert!(next_column_prefix_sum.dimension() <= 30);

                // Concatenate the current prefix sum with the next one.
                let mut merged_prefix_sum = current_column_prefix_sum.clone();
                merged_prefix_sum.extend(next_column_prefix_sum);

                // Checks the merged prefix sum against the (non-symbolic) sumcheck point. The
                // call yields a value used as this column's factor and a felt that is handed
                // back to the caller.
                let (full_lagrange_eval, felt) = C::prefix_sum_checks(
                    builder,
                    merged_prefix_sum.to_vec(),
                    partial_sumcheck_proof.point_and_eval.0.to_vec(),
                );
                prefix_sum_felts.push(felt);
                *z_col_eq_val * full_lagrange_eval
            })
            .sum::<SymbolicExt<SP1Field, SP1ExtensionField>>();
        builder.cycle_tracker_v2_exit();
        // Scale the column sum by the branching program (built from `z_row` and `z_trace`)
        // evaluated at the two halves of the sumcheck point.
        let branching_program = BranchingProgram::new(z_row.clone(), z_trace.clone());
        jagged_eval_sc_expected_eval *=
            branching_program.eval(&first_half_z_index, &second_half_z_index);

        // The recomputed value must match the evaluation carried by the sumcheck proof.
        builder
            .assert_ext_eq(jagged_eval_sc_expected_eval, partial_sumcheck_proof.point_and_eval.1);

        (jagged_eval.into(), prefix_sum_felts)
    }
}
172
173#[cfg(test)]
174mod tests {
175 use std::{marker::PhantomData, sync::Arc};
176
177 use rand::{thread_rng, Rng};
178 use slop_algebra::{extension::BinomialExtensionField, AbstractField};
179 use slop_alloc::CpuBackend;
180 use slop_challenger::{DuplexChallenger, IopCtx};
181 use slop_jagged::{
182 JaggedAssistSumAsPolyCPUImpl, JaggedEvalProver, JaggedEvalSumcheckProver,
183 JaggedLittlePolynomialProverParams, JaggedLittlePolynomialVerifierParams,
184 };
185 use slop_multilinear::Point;
186 use sp1_core_machine::utils::setup_logger;
187 use sp1_hypercube::{inner_perm, log2_ceil_usize};
188 use sp1_primitives::{SP1DiffusionMatrix, SP1GlobalContext};
189 use sp1_recursion_compiler::{
190 circuit::{AsmBuilder, AsmCompiler, AsmConfig, CircuitV2Builder},
191 ir::{Ext, Felt},
192 };
193 use sp1_recursion_executor::Executor;
194
195 use crate::{
196 challenger::DuplexChallengerVariable,
197 jagged::jagged_eval::{
198 RecursiveJaggedEvalConfig, RecursiveJaggedEvalSumcheckConfig,
199 RecursiveTrivialJaggedEvalConfig,
200 },
201 witness::Witnessable,
202 SP1FieldConfigVariable,
203 };
204
205 use sp1_primitives::{SP1Field, SP1Perm};
206 type F = SP1Field;
207 type EF = BinomialExtensionField<SP1Field, 4>;
208 type C = AsmConfig;
209 type SC = SP1GlobalContext;
210
211 fn trivial_jagged_eval(
212 verifier_params: &JaggedLittlePolynomialVerifierParams<F>,
213 z_row: &Point<EF>,
214 z_col: &Point<EF>,
215 z_trace: &Point<EF>,
216 expected_result: EF,
217 should_succeed: bool,
218 ) {
219 let mut builder = AsmBuilder::default();
220 builder.cycle_tracker_v2_enter("trivial-jagged-eval");
221 let verifier_params_variable = verifier_params.read(&mut builder);
222 let z_row_variable = z_row.read(&mut builder);
223 let z_col_variable = z_col.read(&mut builder);
224 let z_trace_variable = z_trace.read(&mut builder);
225 let recursive_jagged_evaluator = RecursiveTrivialJaggedEvalConfig {};
226 let (recursive_jagged_evaluation, _) = <RecursiveTrivialJaggedEvalConfig as RecursiveJaggedEvalConfig<C, ()>>::jagged_evaluation(
227 &recursive_jagged_evaluator,
228 &mut builder,
229 &verifier_params_variable,
230 z_row_variable,
231 z_col_variable,
232 z_trace_variable,
233 &(),
234 &mut (),
235 );
236 let recursive_jagged_evaluation: Ext<F, EF> = builder.eval(recursive_jagged_evaluation);
237 let expected_result: Ext<F, EF> = builder.constant(expected_result);
238 builder.assert_ext_eq(recursive_jagged_evaluation, expected_result);
239 builder.cycle_tracker_v2_exit();
240
241 let block = builder.into_root_block();
242 let mut compiler = AsmCompiler::default();
243 let program = compiler.compile_inner(block).validate().unwrap();
244
245 let mut witness_stream = Vec::new();
246 Witnessable::<AsmConfig>::write(&verifier_params, &mut witness_stream);
247 Witnessable::<AsmConfig>::write(&z_row, &mut witness_stream);
248 Witnessable::<AsmConfig>::write(&z_col, &mut witness_stream);
249 Witnessable::<AsmConfig>::write(&z_trace, &mut witness_stream);
250
251 let mut executor =
252 Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
253 executor.witness_stream = witness_stream.into();
254 if should_succeed {
255 executor.run().unwrap();
256 } else {
257 executor.run().expect_err("invalid proof should not be verified");
258 }
259 }
260
261 fn sumcheck_jagged_eval(
262 prover_params: &JaggedLittlePolynomialProverParams,
263 verifier_params: &JaggedLittlePolynomialVerifierParams<F>,
264 z_row: &Point<EF>,
265 z_col: &Point<EF>,
266 z_trace: &Point<EF>,
267 expected_result: EF,
268 should_succeed: bool,
269 ) -> Vec<Felt<F>> {
270 let prover = JaggedEvalSumcheckProver::<
271 F,
272 JaggedAssistSumAsPolyCPUImpl<_, _, _>,
273 CpuBackend,
274 <SP1GlobalContext as IopCtx>::Challenger,
275 >::default();
276 let default_perm = inner_perm();
277 let mut challenger =
278 DuplexChallenger::<SP1Field, SP1Perm, 16, 8>::new(default_perm.clone());
279 let jagged_eval_proof = prover.prove_jagged_evaluation(
280 prover_params,
281 z_row,
282 z_col,
283 z_trace,
284 &mut challenger,
285 CpuBackend,
286 );
287
288 let mut builder = AsmBuilder::default();
289 builder.cycle_tracker_v2_enter("sumcheck-jagged-eval");
290 let verifier_params_variable = verifier_params.read(&mut builder);
291 let z_row_variable = z_row.read(&mut builder);
292 let z_col_variable = z_col.read(&mut builder);
293 let z_trace_variable = z_trace.read(&mut builder);
294 let jagged_eval_proof_variable = jagged_eval_proof.read(&mut builder);
295 let recursive_jagged_evaluator = RecursiveJaggedEvalSumcheckConfig::<SC>(PhantomData);
296 let mut challenger_variable = DuplexChallengerVariable::new(&mut builder);
297 let (recursive_jagged_evaluation, prefix_sum_felts) =
298 <RecursiveJaggedEvalSumcheckConfig<SC> as RecursiveJaggedEvalConfig<
299 C,
300 <SC as SP1FieldConfigVariable<C>>::FriChallengerVariable,
301 >>::jagged_evaluation(
302 &recursive_jagged_evaluator,
303 &mut builder,
304 &verifier_params_variable,
305 z_row_variable,
306 z_col_variable,
307 z_trace_variable,
308 &jagged_eval_proof_variable,
309 &mut challenger_variable,
310 );
311 let recursive_jagged_evaluation: Ext<F, EF> = builder.eval(recursive_jagged_evaluation);
312 let expected_result: Ext<F, EF> = builder.constant(expected_result);
313 builder.assert_ext_eq(recursive_jagged_evaluation, expected_result);
314 builder.cycle_tracker_v2_exit();
315
316 let block = builder.into_root_block();
317 let mut compiler = AsmCompiler::default();
318 let program = compiler.compile_inner(block).validate().unwrap();
319
320 let mut witness_stream = Vec::new();
321 Witnessable::<AsmConfig>::write(&verifier_params, &mut witness_stream);
322 Witnessable::<AsmConfig>::write(&z_row, &mut witness_stream);
323 Witnessable::<AsmConfig>::write(&z_col, &mut witness_stream);
324 Witnessable::<AsmConfig>::write(&z_trace, &mut witness_stream);
325 Witnessable::<AsmConfig>::write(&jagged_eval_proof, &mut witness_stream);
326 let mut executor =
327 Executor::<F, EF, SP1DiffusionMatrix>::new(Arc::new(program), inner_perm());
328 executor.witness_stream = witness_stream.into();
329 if should_succeed {
330 executor.run().unwrap();
331 } else {
332 executor.run().expect_err("invalid proof should not be verified");
333 }
334 prefix_sum_felts
335 }
336
/// Runs both recursive evaluators on random points: first with the honest inputs, which must
/// be accepted, then with a perturbed `z_row` but the original expected value, which must be
/// rejected.
#[test]
fn test_jagged_eval_proof() {
    setup_logger();

    let log_max_row_count = 7;
    let row_counts = [12, 1, 2, 1, 17, 0];

    // Exclusive prefix sums of the row counts, followed by the grand total.
    let mut running_total: usize = 0;
    let mut prefix_sums = Vec::with_capacity(row_counts.len() + 1);
    for row_count in &row_counts {
        prefix_sums.push(running_total);
        running_total += row_count;
    }
    prefix_sums.push(running_total);

    let log_m = log2_ceil_usize(*prefix_sums.last().unwrap());

    let prover_params =
        JaggedLittlePolynomialProverParams::new(row_counts.to_vec(), log_max_row_count);
    let verifier_params: JaggedLittlePolynomialVerifierParams<F> =
        prover_params.clone().into_verifier_params();

    // Random evaluation points, drawn in the order row, column, trace.
    let mut rng = thread_rng();
    let z_row: Point<EF> = (0..log_max_row_count).map(|_| rng.gen::<EF>()).collect();
    let z_col: Point<EF> =
        (0..log2_ceil_usize(row_counts.len())).map(|_| rng.gen::<EF>()).collect();
    let z_trace: Point<EF> = (0..=log_m).map(|_| rng.gen::<EF>()).collect();

    let expected_result =
        verifier_params.full_jagged_little_polynomial_evaluation(&z_row, &z_col, &z_trace);

    // Honest inputs: both evaluators must accept.
    trivial_jagged_eval(&verifier_params, &z_row, &z_col, &z_trace, expected_result, true);
    sumcheck_jagged_eval(
        &prover_params,
        &verifier_params,
        &z_row,
        &z_col,
        &z_trace,
        expected_result,
        true,
    );

    // Shift the first coordinate of `z_row` while keeping the old expected value: both
    // evaluators must reject.
    let mut z_row_invalid = z_row.clone();
    *z_row_invalid.get_mut(0).unwrap() += EF::one();
    trivial_jagged_eval(
        &verifier_params,
        &z_row_invalid,
        &z_col,
        &z_trace,
        expected_result,
        false,
    );
    sumcheck_jagged_eval(
        &prover_params,
        &verifier_params,
        &z_row_invalid,
        &z_col,
        &z_trace,
        expected_result,
        false,
    );
}
405}