1use std::sync::Arc;
2
3use rayon::prelude::*;
4use slop_algebra::{
5 interpolate_univariate_polynomial, ExtensionField, Field, UnivariatePolynomial,
6};
7use slop_multilinear::{Mle, Point};
8use slop_sumcheck::{ComponentPoly, SumcheckPoly, SumcheckPolyBase, SumcheckPolyFirstRound};
9
10use super::{InteractionLayer, LogUpGkrCpuLayer};
11
12pub struct LogupRoundPolynomial<F, EF> {
14 pub layer: PolynomialLayer<F, EF>,
16 pub eq_row: Arc<Mle<EF>>,
18 pub eq_interaction: Arc<Mle<EF>>,
20 pub eq_adjustment: EF,
22 pub padding_adjustment: EF,
24 pub lambda: EF,
26 pub point: Point<EF>,
28}
29
30pub enum PolynomialLayer<F, EF> {
32 CircuitLayer(LogUpGkrCpuLayer<F, EF>),
34 InteractionLayer(InteractionLayer<F, EF>),
36}
37
38impl<F: Field, EF: ExtensionField<F>> SumcheckPolyBase for LogupRoundPolynomial<F, EF> {
39 fn num_variables(&self) -> u32 {
40 self.eq_row.num_variables() + self.eq_interaction.num_variables()
41 }
42}
43
44impl<K: Field> ComponentPoly<K> for LogupRoundPolynomial<K, K> {
45 fn get_component_poly_evals(&self) -> Vec<K> {
46 match &self.layer {
47 PolynomialLayer::InteractionLayer(layer) => {
48 assert!(layer.numerator_0.guts().as_slice().len() == 1);
49 let numerator_0 = layer.numerator_0.guts().as_slice()[0];
50 let denominator_0 = layer.denominator_0.guts().as_slice()[0];
51 let numerator_1 = layer.numerator_1.guts().as_slice()[0];
52 let denominator_1 = layer.denominator_1.guts().as_slice()[0];
53 vec![numerator_0, denominator_0, numerator_1, denominator_1]
54 }
55 PolynomialLayer::CircuitLayer(_) => unreachable!(),
56 }
57 }
58}
59
60impl<K: Field> SumcheckPoly<K> for LogupRoundPolynomial<K, K> {
61 fn fix_last_variable(self, alpha: K) -> Self {
62 self.fix_t_variables(alpha, 1)
63 }
64
65 fn sum_as_poly_in_last_variable(&self, claim: Option<K>) -> UnivariatePolynomial<K> {
66 self.sum_as_poly_in_last_t_variables(claim, 1)
67 }
68}
69
70impl<K: ExtensionField<F>, F: Field> SumcheckPolyFirstRound<K> for LogupRoundPolynomial<F, K> {
71 type NextRoundPoly = LogupRoundPolynomial<K, K>;
72 #[allow(clippy::too_many_lines)]
73 fn fix_t_variables(mut self, alpha: K, t: usize) -> Self::NextRoundPoly {
74 assert!(t == 1);
75 let last_coordinate = self.point.remove_last_coordinate();
77 let padding_adjustment = self.padding_adjustment
78 * (last_coordinate * alpha + (K::one() - last_coordinate) * (K::one() - alpha));
79 match self.layer {
80 PolynomialLayer::InteractionLayer(layer) => {
81 let numerator_0 =
82 Arc::new(layer.numerator_0.as_ref().fix_last_variable::<K>(alpha));
83 let denominator_0 =
84 Arc::new(layer.denominator_0.as_ref().fix_last_variable::<K>(alpha));
85 let numerator_1 =
86 Arc::new(layer.numerator_1.as_ref().fix_last_variable::<K>(alpha));
87 let denominator_1 =
88 Arc::new(layer.denominator_1.as_ref().fix_last_variable::<K>(alpha));
89
90 let new_layer =
91 InteractionLayer { numerator_0, denominator_0, numerator_1, denominator_1 };
92
93 let eq_interaction =
94 Arc::new(self.eq_interaction.as_ref().fix_last_variable(alpha));
95
96 LogupRoundPolynomial {
97 layer: PolynomialLayer::InteractionLayer(new_layer),
98 eq_row: self.eq_row,
99 eq_interaction,
100 eq_adjustment: self.eq_adjustment,
101 padding_adjustment,
102 lambda: self.lambda,
103 point: self.point,
104 }
105 }
106 PolynomialLayer::CircuitLayer(layer) => {
107 if layer.num_row_variables == 1 {
108 let numerator_0: Vec<_> = layer
109 .numerator_0
110 .into_iter()
111 .map(|mle| mle.fix_last_variable(alpha))
112 .collect();
113 let denominator_0: Vec<_> = layer
114 .denominator_0
115 .into_iter()
116 .map(|mle| mle.fix_last_variable(alpha))
117 .collect();
118 let numerator_1: Vec<_> = layer
119 .numerator_1
120 .into_iter()
121 .map(|mle| mle.fix_last_variable(alpha))
122 .collect();
123 let denominator_1: Vec<_> = layer
124 .denominator_1
125 .into_iter()
126 .map(|mle| mle.fix_last_variable(alpha))
127 .collect();
128
129 let mut numerator_0_interactions: Vec<_> = numerator_0
130 .into_iter()
131 .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
132 .collect();
133 numerator_0_interactions
134 .resize(1 << layer.num_interaction_variables, K::zero());
135
136 let mut numerator_1_interactions: Vec<_> = numerator_1
137 .into_iter()
138 .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
139 .collect();
140 numerator_1_interactions
141 .resize(1 << layer.num_interaction_variables, K::zero());
142
143 let mut denominator_0_interactions: Vec<_> = denominator_0
144 .into_iter()
145 .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
146 .collect();
147 denominator_0_interactions
148 .resize(1 << layer.num_interaction_variables, K::one());
149
150 let mut denominator_1_interactions: Vec<_> = denominator_1
151 .into_iter()
152 .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
153 .collect();
154 denominator_1_interactions
155 .resize(1 << layer.num_interaction_variables, K::one());
156
157 let numerator_0_mle = Arc::new(Mle::from(numerator_0_interactions));
158 let denominator_0_mle = Arc::new(Mle::from(denominator_0_interactions));
159 let numerator_1_mle = Arc::new(Mle::from(numerator_1_interactions));
160 let denominator_1_mle = Arc::new(Mle::from(denominator_1_interactions));
161
162 let new_layer = InteractionLayer {
163 numerator_0: numerator_0_mle,
164 denominator_0: denominator_0_mle,
165 numerator_1: numerator_1_mle,
166 denominator_1: denominator_1_mle,
167 };
168
169 let eq_row = Arc::new(self.eq_row.as_ref().fix_last_variable(alpha));
170
171 LogupRoundPolynomial {
172 layer: PolynomialLayer::InteractionLayer(new_layer),
173 eq_row,
174 eq_interaction: self.eq_interaction,
175 eq_adjustment: padding_adjustment,
176 padding_adjustment: K::one(),
177 lambda: self.lambda,
178 point: self.point,
179 }
180 } else {
181 let numerator_0: Vec<_> = layer
182 .numerator_0
183 .into_iter()
184 .map(|mle| mle.fix_last_variable(alpha))
185 .collect();
186
187 let denominator_0: Vec<_> = layer
188 .denominator_0
189 .into_iter()
190 .map(|mle| mle.fix_last_variable(alpha))
191 .collect();
192
193 let numerator_1: Vec<_> = layer
194 .numerator_1
195 .into_iter()
196 .map(|mle| mle.fix_last_variable(alpha))
197 .collect();
198
199 let denominator_1: Vec<_> = layer
200 .denominator_1
201 .into_iter()
202 .map(|mle| mle.fix_last_variable(alpha))
203 .collect();
204
205 let eq_row = Arc::new(self.eq_row.as_ref().fix_last_variable(alpha));
206
207 let new_layer = LogUpGkrCpuLayer {
208 numerator_0,
209 denominator_0,
210 numerator_1,
211 denominator_1,
212 num_row_variables: layer.num_row_variables - 1,
213 num_interaction_variables: layer.num_interaction_variables,
214 };
215
216 LogupRoundPolynomial {
217 layer: PolynomialLayer::CircuitLayer(new_layer),
218 eq_row,
219 eq_interaction: self.eq_interaction,
220 eq_adjustment: self.eq_adjustment,
221 padding_adjustment,
222 lambda: self.lambda,
223 point: self.point,
224 }
225 }
226 }
227 }
228 }
229
230 #[allow(clippy::too_many_lines)]
231 fn sum_as_poly_in_last_t_variables(
232 &self,
233 claim: Option<K>,
234 t: usize,
235 ) -> UnivariatePolynomial<K> {
236 assert!(t == 1);
237 let claim = claim.unwrap();
238
239 let (mut eval_zero, mut eval_half, eq_sum) = match &self.layer {
240 PolynomialLayer::InteractionLayer(layer) => {
241 let numerator_0 = layer.numerator_0.clone();
242 let numerator_1 = layer.numerator_1.clone();
243 let denominator_0 = layer.denominator_0.clone();
244 let denominator_1 = layer.denominator_1.clone();
245 let eq_interaction = self.eq_interaction.clone();
246 let lambda = self.lambda;
247 let numerator_eval_0 = numerator_0
248 .guts()
249 .as_slice()
250 .par_iter()
251 .step_by(2)
252 .zip_eq(numerator_1.guts().as_slice().par_iter().step_by(2))
253 .zip_eq(denominator_0.guts().as_slice().par_iter().step_by(2))
254 .zip_eq(denominator_1.guts().as_slice().par_iter().step_by(2))
255 .zip_eq(eq_interaction.guts().as_slice().par_iter().step_by(2))
256 .map(|((((n0, n1), d0), d1), e)| *e * (*d0 * *n1 + *d1 * *n0))
257 .sum::<K>();
258
259 let numerator_eval_half = numerator_0
260 .guts()
261 .as_slice()
262 .par_chunks(2)
263 .zip_eq(numerator_1.guts().as_slice().par_chunks(2))
264 .zip_eq(denominator_0.guts().as_slice().par_chunks(2))
265 .zip_eq(denominator_1.guts().as_slice().par_chunks(2))
266 .zip_eq(eq_interaction.guts().as_slice().par_chunks(2))
267 .map(|((((n0_chunk, n1_chunk), d0_chunk), d1_chunk), e_chunk)| {
268 let n0_half = n0_chunk[0] + n0_chunk[1];
269 let n1_half = n1_chunk[0] + n1_chunk[1];
270 let d0_half = d0_chunk[0] + d0_chunk[1];
271 let d1_half = d1_chunk[0] + d1_chunk[1];
272 let e_half = e_chunk[0] + e_chunk[1];
273 e_half * (d0_half * n1_half + d1_half * n0_half)
274 })
275 .sum::<K>();
276
277 let denominator_eval_0 = denominator_0
278 .guts()
279 .as_slice()
280 .par_iter()
281 .step_by(2)
282 .zip_eq(denominator_1.guts().as_slice().par_iter().step_by(2))
283 .zip_eq(eq_interaction.guts().as_slice().par_iter().step_by(2))
284 .map(|((d0, d1), e)| *e * (*d0 * *d1))
285 .sum::<K>();
286
287 let denominator_eval_half = denominator_0
288 .guts()
289 .as_slice()
290 .par_chunks(2)
291 .zip_eq(denominator_1.guts().as_slice().par_chunks(2))
292 .zip_eq(eq_interaction.guts().as_slice().par_chunks(2))
293 .map(|((d0_chunk, d1_chunk), e_chunk)| {
294 let d0_half = d0_chunk[0] + d0_chunk[1];
295 let d1_half = d1_chunk[0] + d1_chunk[1];
296 let e_half = e_chunk[0] + e_chunk[1];
297 e_half * (d0_half * d1_half)
298 })
299 .sum::<K>();
300
301 let eq_half_sum = eq_interaction
302 .guts()
303 .as_slice()
304 .par_chunks(2)
305 .map(|e_chunk| e_chunk[0] + e_chunk[1])
306 .sum::<K>();
307
308 (
309 lambda * numerator_eval_0 + denominator_eval_0,
310 lambda * numerator_eval_half + denominator_eval_half,
311 eq_half_sum,
312 )
313 }
314 PolynomialLayer::CircuitLayer(layer) => {
315 let numerator_0 = layer.numerator_0.clone();
316 let numerator_1 = layer.numerator_1.clone();
317 let denominator_0 = layer.denominator_0.clone();
318 let denominator_1 = layer.denominator_1.clone();
319 let eq_row = self.eq_row.clone();
320 assert!(eq_row.num_non_zero_entries().is_multiple_of(2));
322 let eq_interaction = self.eq_interaction.clone();
323 let lambda = self.lambda;
324
325 let mut interaction_offset = 0;
326 let mut eval_0 = K::zero();
327 let mut eval_half = K::zero();
328 let mut eq_sum = K::zero();
329 for (numerator_0, numerator_1, denominator_0, denominator_1) in
330 itertools::izip!(numerator_0, numerator_1, denominator_0, denominator_1)
331 {
332 if let Some(inner) = numerator_0.inner() {
333 assert!(numerator_0.num_variables() > 0);
334 let numerator_1_inner = numerator_1.inner().as_ref().unwrap();
335 let denominator_0_inner = denominator_0.inner().as_ref().unwrap();
340 let denominator_1_inner = denominator_1.inner().as_ref().unwrap();
341 let (eval_0_chip, eval_half_chip, eq_sum_chip) =
342 inner
343 .guts()
344 .as_slice()
345 .par_chunks(2 * numerator_0.num_polynomials())
346 .zip_eq(
347 numerator_1_inner
348 .guts()
349 .as_slice()
350 .par_chunks(2 * numerator_1_inner.num_polynomials()),
351 )
352 .zip_eq(
353 denominator_0_inner
354 .guts()
355 .as_slice()
356 .par_chunks(2 * denominator_0_inner.num_polynomials()),
357 )
358 .zip_eq(
359 denominator_1_inner
360 .guts()
361 .as_slice()
362 .par_chunks(2 * denominator_1_inner.num_polynomials()),
363 )
364 .zip(eq_row.guts().as_slice().par_chunks(2))
365 .map(
366 |(
367 (((numer_0_row, numer_1_row), denom_0_row), denom_1_row),
368 eq_row_chunk,
369 )| {
370 let eq_interactions_chip = eq_interaction.guts().as_slice()
371 [interaction_offset
372 ..interaction_offset
373 + numerator_0.num_polynomials()]
374 .par_iter();
375
376 let (numer_0_row_0, numer_0_row_1) =
377 numer_0_row.split_at(numerator_0.num_polynomials());
378 let (denom_0_row_0, denom_0_row_1) =
379 denom_0_row.split_at(denominator_0.num_polynomials());
380 let (denom_1_row_0, denom_1_row_1) =
381 denom_1_row.split_at(denominator_1.num_polynomials());
382 let (numer_1_row_0, numer_1_row_1) =
383 numer_1_row.split_at(numerator_1.num_polynomials());
384 let eq_row_0 = eq_row_chunk[0];
385 let eq_row_1 = eq_row_chunk[1];
386 if numer_0_row.len() == 2 * numerator_0.num_polynomials() {
387 let numerator_0_eval = numer_0_row_0
388 .par_iter()
389 .zip_eq(numer_1_row_0.par_iter())
390 .zip_eq(denom_0_row_0.par_iter())
391 .zip_eq(denom_1_row_0.par_iter())
392 .zip_eq(eq_interactions_chip.clone())
393 .map(|((((n0, n1), d0), d1), e)| {
394 *e * (*d0 * *n1 + *d1 * *n0)
396 })
397 .sum::<K>();
398 let denominator_0_eval = denom_0_row_0
399 .par_iter()
400 .zip_eq(denom_1_row_0.par_iter())
401 .zip_eq(eq_interactions_chip.clone())
402 .map(|((d0, d1), e)| *e * (*d0 * *d1))
403 .sum::<K>();
404 let numerator_half_eval = numer_0_row_0
405 .par_iter()
406 .zip_eq(numer_1_row_0.par_iter())
407 .zip_eq(denom_0_row_0.par_iter())
408 .zip_eq(denom_1_row_0.par_iter())
409 .zip_eq(numer_0_row_1.par_iter())
410 .zip_eq(numer_1_row_1.par_iter())
411 .zip_eq(denom_0_row_1.par_iter())
412 .zip_eq(denom_1_row_1.par_iter())
413 .zip_eq(eq_interactions_chip.clone())
414 .map(
415 |(((
416 (
417 (
418 (((n0_0, n1_0), d0_0), d1_0),
419 n0_1,
420 ),
421 n1_1,
422 ),
423 d0_1), d1_1),
424 e,
425 )| {
426 *e * ((*d0_0 + *d0_1) * (*n1_0 + *n1_1)
427 + (*d1_0 + *d1_1) * (*n0_0 + *n0_1))
428 },
429 )
430 .sum::<K>();
431 let denominator_half_eval = denom_0_row_0
432 .par_iter()
433 .zip_eq(denom_1_row_0.par_iter())
434 .zip_eq(denom_0_row_1.par_iter())
435 .zip_eq(denom_1_row_1.par_iter())
436 .zip_eq(eq_interactions_chip.clone())
437 .map(|((((d0_0, d1_0), d0_1), d1_1), e)| {
438 *e * ((*d0_0 + *d0_1) * (*d1_0 + *d1_1))
439 })
440 .sum::<K>();
441 let eq_interactions_chip_half = eq_interactions_chip
442 .map(|e| *e * (eq_row_0 + eq_row_1))
443 .sum::<K>();
444 (
445 (lambda * numerator_0_eval + denominator_0_eval)
446 * eq_row_0,
447 (lambda * numerator_half_eval
448 + denominator_half_eval)
449 * (eq_row_0 + eq_row_1),
450 eq_interactions_chip_half,
451 )
452 } else {
453 let numerator_0_eval = numer_0_row_0
454 .par_iter()
455 .zip_eq(numer_1_row_0.par_iter())
456 .zip_eq(denom_0_row_0.par_iter())
457 .zip_eq(denom_1_row_0.par_iter())
458 .zip_eq(eq_interactions_chip.clone())
459 .map(|((((n0, n1), d0), d1), e)| {
460 *e * (*d0 * *n1 + *d1 * *n0)
461 })
462 .sum::<K>();
463 let denominator_0_eval = denom_0_row_0
464 .par_iter()
465 .zip_eq(denom_1_row_0.par_iter())
466 .zip_eq(eq_interactions_chip.clone())
467 .map(|((d0, d1), e)| *e * (*d0 * *d1))
468 .sum::<K>();
469 let numerator_half_eval = numer_0_row_0
470 .par_iter()
471 .zip_eq(numer_1_row_0.par_iter())
472 .zip_eq(denom_0_row_0.par_iter())
473 .zip_eq(denom_1_row_0.par_iter())
474 .zip_eq(eq_interactions_chip.clone())
475 .map(|((((n0, n1), d0), d1), e)| {
476 *e * ((*d0 + K::one()) * *n1
477 + (*d1 + K::one()) * *n0)
478 })
479 .sum::<K>();
480 let denominator_half_eval = denom_0_row_0
481 .par_iter()
482 .zip_eq(denom_1_row_0.par_iter())
483 .zip_eq(eq_interactions_chip.clone())
484 .map(|((d0, d1), e)| {
485 *e * ((*d0 + K::one()) * (*d1 + K::one()))
486 })
487 .sum::<K>();
488 let eq_interactions_chip_half = eq_interactions_chip
489 .map(|e| *e * (eq_row_0 + eq_row_1))
490 .sum::<K>();
491 (
492 (lambda * numerator_0_eval + denominator_0_eval)
493 * eq_row_0,
494 (lambda * numerator_half_eval
495 + denominator_half_eval)
496 * (eq_row_0 + eq_row_1),
497 eq_interactions_chip_half,
498 )
499 }
500 },
501 )
502 .reduce(
503 || (K::zero(), K::zero(), K::zero()),
504 |(y_0_acc, y_half_acc, eq_sum_acc), (y_0, y_half, eq_sum)| {
505 (y_0_acc + y_0, y_half_acc + y_half, eq_sum_acc + eq_sum)
506 },
507 );
508 eval_0 += eval_0_chip;
509 eval_half += eval_half_chip;
510 eq_sum += eq_sum_chip;
511 }
512 interaction_offset += numerator_0.num_polynomials();
513 }
515
516 (eval_0, eval_half, eq_sum)
517 }
518 };
519
520 let eq_correction_term = self.padding_adjustment - eq_sum;
524 eval_zero += eq_correction_term * (K::one() - *self.point.last().unwrap());
527 eval_half += eq_correction_term * K::from_canonical_u16(4);
530
531 let eval_half = eval_half * K::from_canonical_u16(8).inverse();
534
535 let eval_zero = eval_zero * self.eq_adjustment;
536 let eval_half = eval_half * self.eq_adjustment;
537
538 let point_last = self.point.last().unwrap();
540 let b_const = (K::one() - *point_last) / (K::one() - point_last.double());
541
542 let eval_one = claim - eval_zero;
543
544 interpolate_univariate_polynomial(
545 &[
546 K::from_canonical_u16(0),
547 K::from_canonical_u16(1),
548 K::from_canonical_u16(2).inverse(),
549 b_const,
550 ],
551 &[eval_zero, eval_one, eval_half, K::zero()],
552 )
553 }
554}
555
556#[cfg(test)]
557mod tests {
558 use crate::{prove_gkr_round, GkrCircuitLayer, LogupGkrCpuTraceGenerator};
559
560 use super::*;
561 use itertools::Itertools;
562 use rand::{thread_rng, Rng};
563 use slop_algebra::{extension::BinomialExtensionField, AbstractField};
564 use slop_alloc::CpuBackend;
565
566 use slop_challenger::{FieldChallenger, IopCtx};
567 use slop_matrix::dense::RowMajorMatrix;
568 use slop_multilinear::{PaddedMle, Padding};
569 use slop_sumcheck::{partially_verify_sumcheck_proof, reduce_sumcheck_to_evaluation};
570 use slop_tensor::Tensor;
571 use sp1_primitives::SP1Field;
572
573 type EF = BinomialExtensionField<SP1Field, 4>;
574 type F = SP1Field;
575
576 fn random_layer(
577 rng: &mut impl Rng,
578 interaction_counts: &[usize],
579 num_rows: usize,
580 num_row_variables: usize,
581 num_interaction_variables: usize,
582 ) -> LogUpGkrCpuLayer<F, EF> {
583 let numerator_0 = interaction_counts
584 .iter()
585 .map(|count| {
586 let guts = Tensor::<F>::rand(rng, [num_rows, *count]);
587 Mle::new(guts)
588 })
589 .collect::<Vec<_>>();
590 let denominator_0 = interaction_counts
591 .iter()
592 .map(|count| {
593 let guts = Tensor::<EF>::rand(rng, [num_rows, *count]);
594 Mle::new(guts)
595 })
596 .collect::<Vec<_>>();
597 let numerator_1 = interaction_counts
598 .iter()
599 .map(|count| {
600 let guts = Tensor::<F>::rand(rng, [num_rows, *count]);
601 Mle::new(guts)
602 })
603 .collect::<Vec<_>>();
604 let denominator_1 = interaction_counts
605 .iter()
606 .map(|count| {
607 let guts = Tensor::<EF>::rand(rng, [num_rows, *count]);
608 Mle::new(guts)
609 })
610 .collect::<Vec<_>>();
611
612 let padded_numerator_0 = numerator_0
613 .iter()
614 .map(|mle| {
615 PaddedMle::padded_with_zeros(Arc::new(mle.clone()), num_row_variables as u32)
616 })
617 .collect::<Vec<_>>();
618
619 let padded_denominator_0 = denominator_0
620 .iter()
621 .map(|mle| {
622 let num_polys = mle.num_polynomials();
623 PaddedMle::padded(
624 Arc::new(mle.clone()),
625 num_row_variables as u32,
626 Padding::Constant((EF::one(), num_polys, CpuBackend)),
627 )
628 })
629 .collect::<Vec<_>>();
630
631 let padded_numerator_1 = numerator_1
632 .iter()
633 .map(|mle| {
634 PaddedMle::padded_with_zeros(Arc::new(mle.clone()), num_row_variables as u32)
635 })
636 .collect::<Vec<_>>();
637 let padded_denominator_1 = denominator_1
638 .iter()
639 .map(|mle| {
640 let num_polys = mle.num_polynomials();
641 PaddedMle::padded(
642 Arc::new(mle.clone()),
643 num_row_variables as u32,
644 Padding::Constant((EF::one(), num_polys, CpuBackend)),
645 )
646 })
647 .collect::<Vec<_>>();
648
649 LogUpGkrCpuLayer {
650 numerator_0: padded_numerator_0,
651 denominator_0: padded_denominator_0,
652 numerator_1: padded_numerator_1,
653 denominator_1: padded_denominator_1,
654 num_row_variables,
655 num_interaction_variables,
656 }
657 }
658
    /// Fixes every variable of a circuit-layer round polynomial, one at a time, and checks two
    /// things against direct evaluation of the original MLEs:
    /// - the flattened interaction layer;
    /// - the final component evaluations.
    #[test]
    #[allow(clippy::too_many_lines)]
    fn test_logup_poly_fix_last_variable() {
        let mut rng = thread_rng();
        let interaction_counts = vec![1];
        let num_rows: usize = 4;
        let num_row_variables = 2;
        let num_interaction_variables =
            interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
        let layer = random_layer(
            &mut rng,
            &interaction_counts,
            num_rows,
            num_row_variables as usize,
            num_interaction_variables as usize,
        );

        // The point is laid out as (interaction coordinates, row coordinates).
        let poly_point = Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
        let (interaction_point, row_point) =
            poly_point.split_at(num_interaction_variables as usize);

        let random_point =
            Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
        let (interaction_random_point, row_random_point) =
            random_point.split_at(num_interaction_variables as usize);

        let lambda = rng.gen::<EF>();
        let eq_row = Mle::partial_lagrange(&row_point);
        let eq_interaction = Mle::partial_lagrange(&interaction_point);

        let first_polynomial = LogupRoundPolynomial {
            layer: PolynomialLayer::CircuitLayer(layer),
            eq_row: Arc::new(eq_row),
            eq_interaction: Arc::new(eq_interaction),
            eq_adjustment: EF::one(),
            padding_adjustment: EF::one(),
            lambda,
            point: poly_point,
        };

        let PolynomialLayer::CircuitLayer(layer) = &first_polynomial.layer else {
            panic!("first polynomial is not a circuit layer");
        };

        // Expected interaction-layer contents: each chip evaluated at the random row point,
        // padded with 0 (numerators) or 1 (denominators) up to 2^num_interaction_variables.
        let mut numerator_0_interactions: Vec<EF> = layer
            .numerator_0
            .iter()
            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
            .collect();
        numerator_0_interactions.resize(1 << layer.num_interaction_variables, EF::zero());

        let mut numerator_1_interactions: Vec<EF> = layer
            .numerator_1
            .iter()
            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
            .collect();
        numerator_1_interactions.resize(1 << layer.num_interaction_variables, EF::zero());

        let mut denominator_0_interactions: Vec<EF> = layer
            .denominator_0
            .iter()
            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
            .collect();
        denominator_0_interactions.resize(1 << layer.num_interaction_variables, EF::one());

        let mut denominator_1_interactions: Vec<EF> = layer
            .denominator_1
            .iter()
            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
            .collect();
        denominator_1_interactions.resize(1 << layer.num_interaction_variables, EF::one());

        // Fix the row variables from the last coordinate to the first.
        let mut round_polynomial =
            first_polynomial.fix_t_variables(*row_random_point.last().unwrap(), 1);

        for alpha in row_random_point.iter().rev().skip(1) {
            round_polynomial = round_polynomial.fix_t_variables(*alpha, 1);
        }

        let PolynomialLayer::InteractionLayer(interaction_layer) = &round_polynomial.layer else {
            panic!("round polynomial is not an interaction layer");
        };

        for (i, numerator_0_interaction) in numerator_0_interactions.iter().enumerate() {
            assert_eq!(
                *numerator_0_interaction,
                interaction_layer.numerator_0.guts().as_slice()[i]
            );
        }
        for (i, numerator_1_interaction) in numerator_1_interactions.iter().enumerate() {
            assert_eq!(
                *numerator_1_interaction,
                interaction_layer.numerator_1.guts().as_slice()[i]
            );
        }
        for (i, denominator_0_interaction) in denominator_0_interactions.iter().enumerate() {
            assert_eq!(
                *denominator_0_interaction,
                interaction_layer.denominator_0.guts().as_slice()[i]
            );
        }
        for (i, denominator_1_interaction) in denominator_1_interactions.iter().enumerate() {
            assert_eq!(
                *denominator_1_interaction,
                interaction_layer.denominator_1.guts().as_slice()[i]
            );
        }

        let numerator_0_eval = interaction_layer.numerator_0.eval_at(&interaction_random_point)[0];
        let numerator_1_eval = interaction_layer.numerator_1.eval_at(&interaction_random_point)[0];
        let denominator_0_eval =
            interaction_layer.denominator_0.eval_at(&interaction_random_point)[0];
        let denominator_1_eval =
            interaction_layer.denominator_1.eval_at(&interaction_random_point)[0];

        // Fix the remaining interaction variables from the last coordinate to the first.
        for alpha in interaction_random_point.iter().rev() {
            round_polynomial = round_polynomial.fix_t_variables(*alpha, 1);
        }

        let [n0, d0, n1, d1] = round_polynomial.get_component_poly_evals().try_into().unwrap();

        assert_eq!(numerator_0_eval, n0);
        assert_eq!(numerator_1_eval, n1);
        assert_eq!(denominator_0_eval, d0);
        assert_eq!(denominator_1_eval, d1);
    }
789
790 #[test]
791 #[allow(clippy::too_many_lines)]
792 fn test_logup_poly_sumcheck_circuit_layer() {
793 type GC = sp1_primitives::SP1GlobalContext;
794 let mut rng = thread_rng();
795
796 let get_challenger = move || GC::default_challenger();
797
798 let interaction_counts = vec![4, 5, 6];
799 let num_rows: usize = 8;
800 let num_row_variables = 4;
801
802 let num_interaction_variables =
803 interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
804 let layer = random_layer(
805 &mut rng,
806 &interaction_counts,
807 num_rows,
808 num_row_variables as usize,
809 num_interaction_variables as usize,
810 );
811
812 let poly_point = Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
813 let (interaction_point, row_point) =
814 poly_point.split_at(num_interaction_variables as usize);
815
816 let eq_row = Mle::partial_lagrange(&row_point);
817 let eq_interaction = Mle::partial_lagrange(&interaction_point);
818
819 let numerator_0 = layer.numerator_0.clone();
820 let numerator_1 = layer.numerator_1.clone();
821 let denominator_0 = layer.denominator_0.clone();
822 let denominator_1 = layer.denominator_1.clone();
823 let lambda = rng.gen::<EF>();
824
825 let round_polynomial = LogupRoundPolynomial {
826 layer: PolynomialLayer::CircuitLayer(layer),
827 eq_row: Arc::new(eq_row),
828 eq_interaction: Arc::new(eq_interaction),
829 eq_adjustment: EF::one(),
830 padding_adjustment: EF::one(),
831 lambda,
832 point: poly_point.clone(),
833 };
834
835 let total_eq = Mle::partial_lagrange(&poly_point);
836
837 let total_eq_guts = total_eq.guts().as_slice().to_vec().clone();
838
839 let claim = {
840 let mut offset = 0;
841 let real_claim = numerator_0
842 .iter()
843 .zip_eq(numerator_1.iter())
844 .zip_eq(denominator_0.iter())
845 .zip_eq(denominator_1.iter())
846 .map(|(((n_0, n_1), d_0), d_1)| {
847 let num_padding = vec![
849 F::zero();
850 ((1 << num_row_variables) - num_rows)
851 * n_0.num_polynomials()
852 ];
853 let den_padding = vec![
854 EF::one();
855 ((1 << num_row_variables) - num_rows)
856 * d_0.num_polynomials()
857 ];
858
859 let padded_n0 = n_0
860 .inner()
861 .as_ref()
862 .unwrap()
863 .guts()
864 .as_slice()
865 .iter()
866 .copied()
867 .chain(num_padding.iter().copied())
868 .collect::<Vec<_>>();
869 let padded_n1 = n_1
870 .inner()
871 .as_ref()
872 .unwrap()
873 .guts()
874 .as_slice()
875 .iter()
876 .copied()
877 .chain(num_padding.iter().copied())
878 .collect::<Vec<_>>();
879 let padded_d0 = d_0
880 .inner()
881 .as_ref()
882 .unwrap()
883 .guts()
884 .as_slice()
885 .iter()
886 .copied()
887 .chain(den_padding.iter().copied())
888 .collect::<Vec<_>>();
889 let padded_d1 = d_1
890 .inner()
891 .as_ref()
892 .unwrap()
893 .guts()
894 .as_slice()
895 .iter()
896 .copied()
897 .chain(den_padding.iter().copied())
898 .collect::<Vec<_>>();
899 let padded_d0 =
900 Mle::from(RowMajorMatrix::new(padded_d0, d_0.num_polynomials()));
901 let padded_d1 =
902 Mle::from(RowMajorMatrix::new(padded_d1, d_1.num_polynomials()));
903 let padded_n0 =
904 Mle::from(RowMajorMatrix::new(padded_n0, n_0.num_polynomials()));
905 let padded_n1 =
906 Mle::from(RowMajorMatrix::new(padded_n1, n_1.num_polynomials()));
907
908 let result = padded_n0
909 .guts()
910 .transpose()
911 .as_slice()
912 .iter()
913 .zip_eq(padded_n1.guts().transpose().as_slice().iter())
914 .zip_eq(padded_d0.guts().transpose().as_slice().iter())
915 .zip_eq(padded_d1.guts().transpose().as_slice().iter())
916 .zip(total_eq_guts.iter().skip(offset))
917 .map(|((((n_0, n_1), d_0), d_1), e)| {
918 let numerator_eval = *d_1 * *n_0 + *d_0 * *n_1;
919 let denominator_eval = *d_0 * *d_1;
920 *e * (numerator_eval * lambda + denominator_eval)
921 })
922 .sum::<EF>();
923
924 offset += padded_n0.guts().as_slice().len();
925 result
926 })
927 .sum::<EF>();
928 let remaining_eq = total_eq_guts.iter().copied().skip(offset).sum::<EF>();
929 real_claim + remaining_eq
930 };
931
932 let mut challenger = get_challenger();
933 let (proof, evals) = reduce_sumcheck_to_evaluation(
934 vec![round_polynomial],
935 &mut challenger,
936 vec![claim],
937 1,
938 EF::one(),
939 );
940
941 let mut challenger = get_challenger();
942 partially_verify_sumcheck_proof(
943 &proof,
944 &mut challenger,
945 (num_row_variables + num_interaction_variables) as usize,
946 3,
947 )
948 .unwrap();
949
950 let (point, expected_final_eval) = proof.point_and_eval;
951
952 assert_eq!(point.dimension() as u32, num_row_variables + num_interaction_variables);
954
955 let [evals] = evals.try_into().unwrap();
957 assert_eq!(evals.len(), 4);
958 let [n_0, d_0, n_1, d_1] = evals.try_into().unwrap();
959
960 let eq_eval = Mle::full_lagrange_eval(&poly_point, &point);
961
962 let expected_numerator_eval = n_0 * d_1 + n_1 * d_0;
963 let expected_denominator_eval = d_0 * d_1;
964 let eval = expected_numerator_eval * lambda + expected_denominator_eval;
965 let final_eval = eq_eval * eval;
966
967 assert_eq!(final_eval, expected_final_eval);
969 }
970
971 #[test]
972 #[allow(clippy::too_many_lines)]
973 fn test_logup_gkr_circuit_transition() {
974 type TraceGenerator = LogupGkrCpuTraceGenerator<SP1Field, EF, ()>;
975 let mut rng = thread_rng();
976
977 let trace_generator = TraceGenerator::default();
978
979 let interaction_counts = vec![4, 5, 6];
980 let num_rows: usize = 8;
981 let num_row_variables = 4;
982 let num_interaction_variables =
983 interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
984 let layer = random_layer(
985 &mut rng,
986 &interaction_counts,
987 num_rows,
988 num_row_variables as usize,
989 num_interaction_variables as usize,
990 );
991 let next_layer = trace_generator.layer_transition(&layer);
992
993 let curr_numerator_0 = layer.numerator_0;
994 let curr_numerator_1 = layer.numerator_1;
995 let curr_denominator_0 = layer.denominator_0;
996 let curr_denominator_1 = layer.denominator_1;
997
998 let next_numerator_0 = next_layer.numerator_0;
999 let next_numerator_1 = next_layer.numerator_1;
1000 let next_denominator_0 = next_layer.denominator_0;
1001 let next_denominator_1 = next_layer.denominator_1;
1002
1003 for (next_n0, next_n1, next_d0, next_d1, curr_n0, curr_n1, curr_d0, curr_d1) in itertools::izip!(
1004 next_numerator_0.iter(),
1005 next_numerator_1.iter(),
1006 next_denominator_0.iter(),
1007 next_denominator_1.iter(),
1008 curr_numerator_0.iter(),
1009 curr_numerator_1.iter(),
1010 curr_denominator_0.iter(),
1011 curr_denominator_1.iter()
1012 ) {
1013 let next_n1_inner = next_n1.inner().as_ref().unwrap();
1014 let next_n0_inner = next_n0.inner().as_ref().unwrap();
1015 let next_d0_inner = next_d0.inner().as_ref().unwrap();
1016 let next_d1_inner = next_d1.inner().as_ref().unwrap();
1017 let curr_n0_inner = curr_n0.inner().as_ref().unwrap();
1018 let curr_n1_inner = curr_n1.inner().as_ref().unwrap();
1019 let curr_d0_inner = curr_d0.inner().as_ref().unwrap();
1020 let curr_d1_inner = curr_d1.inner().as_ref().unwrap();
1021 let _ = next_n0_inner
1022 .guts()
1023 .transpose()
1024 .as_slice()
1025 .chunks(next_n0.num_real_entries())
1026 .zip_eq(
1027 next_n1_inner.guts().transpose().as_slice().chunks(next_n1.num_real_entries()),
1028 )
1029 .zip_eq(
1030 curr_n0_inner
1031 .guts()
1032 .transpose()
1033 .as_slice()
1034 .chunks(curr_n0.num_real_entries())
1035 .zip_eq(
1036 curr_n1_inner
1037 .guts()
1038 .transpose()
1039 .as_slice()
1040 .chunks(curr_n1.num_real_entries()),
1041 ),
1042 )
1043 .zip_eq(
1044 curr_d0_inner
1045 .guts()
1046 .transpose()
1047 .as_slice()
1048 .chunks(curr_d0.num_real_entries())
1049 .zip_eq(
1050 curr_d1_inner
1051 .guts()
1052 .transpose()
1053 .as_slice()
1054 .chunks(curr_d1.num_real_entries()),
1055 ),
1056 )
1057 .map(
1058 |(
1059 ((n0_col, n1_col), (curr_n0_col, curr_n1_col)),
1060 (curr_d0_col, curr_d1_col),
1061 )| {
1062 let next_n = n0_col.iter().interleave(n1_col.iter()).collect::<Vec<_>>();
1063 for (
1064 i,
1065 ((((next_n_val, curr_n0_val), curr_n1_val), curr_d0_val), curr_d1_val),
1066 ) in next_n
1067 .iter()
1068 .copied()
1069 .zip_eq(curr_n0_col.iter())
1070 .zip_eq(curr_n1_col.iter())
1071 .zip_eq(curr_d0_col.iter())
1072 .zip_eq(curr_d1_col.iter())
1073 .enumerate()
1074 {
1075 assert_eq!(
1076 *next_n_val,
1077 *curr_d1_val * *curr_n0_val + *curr_d0_val * *curr_n1_val,
1078 "failed at index {i}"
1079 );
1080 }
1081 },
1082 );
1083 let _ = next_d0_inner
1084 .guts()
1085 .transpose()
1086 .as_slice()
1087 .chunks(next_d0.num_real_entries())
1088 .zip_eq(
1089 next_d1_inner.guts().transpose().as_slice().chunks(next_d1.num_real_entries()),
1090 )
1091 .zip_eq(
1092 curr_d0_inner
1093 .guts()
1094 .transpose()
1095 .as_slice()
1096 .chunks(curr_d0.num_real_entries())
1097 .zip_eq(
1098 curr_d1_inner
1099 .guts()
1100 .transpose()
1101 .as_slice()
1102 .chunks(curr_d1.num_real_entries()),
1103 ),
1104 )
1105 .map(|((next_d0_col, next_d1_col), (curr_d0_col, curr_d1_col))| {
1106 let next_d =
1107 next_d0_col.iter().interleave(next_d1_col.iter()).collect::<Vec<_>>();
1108 for (i, ((next_d_val, curr_d0_val), curr_d1_val)) in next_d
1109 .iter()
1110 .copied()
1111 .zip_eq(curr_d0_col.iter())
1112 .zip_eq(curr_d1_col.iter())
1113 .enumerate()
1114 {
1115 assert_eq!(*next_d_val, *curr_d0_val * *curr_d1_val, "failed at index {i}");
1116 }
1117 });
1118 }
1119 }
1120
1121 #[test]
1122 fn test_logup_gkr_round_prover() {
1123 type GC = sp1_primitives::SP1GlobalContext;
1124 type TraceGenerator = LogupGkrCpuTraceGenerator<SP1Field, EF, ()>;
1125 let get_challenger = move || GC::default_challenger();
1126 let trace_generator = TraceGenerator::default();
1127
1128 let mut rng = thread_rng();
1129
1130 let interaction_counts = vec![4, 5, 6];
1131 let num_interaction_variables =
1132 interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
1133 let num_rows: usize = 32;
1134 let num_row_variables = 7;
1135 let input_layer = random_layer(
1136 &mut rng,
1137 &interaction_counts,
1138 num_rows,
1139 num_row_variables as usize,
1140 num_interaction_variables as usize,
1141 );
1142
1143 let first_eval_point = Point::<EF>::rand(&mut rng, num_interaction_variables + 1);
1144
1145 let layer = GkrCircuitLayer::FirstLayer(input_layer);
1146
1147 let mut layers = vec![layer];
1148 for _ in 0..num_row_variables - 1 {
1149 let next_layer = match layers.last().unwrap() {
1150 GkrCircuitLayer::Layer(layer) => trace_generator.layer_transition(layer),
1151 GkrCircuitLayer::FirstLayer(layer) => trace_generator.layer_transition(layer),
1152 };
1153 layers.push(GkrCircuitLayer::Layer(next_layer));
1154 }
1155 layers.reverse();
1156
1157 let GkrCircuitLayer::Layer(first_layer) = layers.first().unwrap() else {
1158 panic!("first layer not correct");
1159 };
1160
1161 let output = trace_generator.extract_outputs(first_layer);
1162 assert_eq!(output.numerator.num_variables(), num_interaction_variables + 1);
1163 assert_eq!(output.denominator.num_variables(), num_interaction_variables + 1);
1164
1165 let first_numerator_eval = output.numerator.eval_at(&first_eval_point)[0];
1166 let first_denominator_eval = output.denominator.eval_at(&first_eval_point)[0];
1167
1168 let mut challenger = get_challenger();
1169 let mut round_proofs = Vec::new();
1170 let mut numerator_eval = first_numerator_eval;
1171 let mut denominator_eval = first_denominator_eval;
1172 let mut eval_point = first_eval_point.clone();
1173
1174 for layer in layers {
1175 let round_proof = prove_gkr_round(
1176 layer,
1177 &eval_point,
1178 numerator_eval,
1179 denominator_eval,
1180 &mut challenger,
1181 );
1182 challenger.observe_ext_element(round_proof.numerator_0);
1184 challenger.observe_ext_element(round_proof.denominator_0);
1185 challenger.observe_ext_element(round_proof.numerator_1);
1186 challenger.observe_ext_element(round_proof.denominator_1);
1187 eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
1189 let last_coordinate = challenger.sample_ext_element::<EF>();
1191
1192 numerator_eval = round_proof.numerator_0
1194 + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
1195 denominator_eval = round_proof.denominator_0
1196 + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
1197 eval_point.add_dimension_back(last_coordinate);
1198 round_proofs.push(round_proof);
1200 }
1201
1202 let mut challenger = get_challenger();
1204 let mut numerator_eval = first_numerator_eval;
1205 let mut denominator_eval = first_denominator_eval;
1206 let mut eval_point = first_eval_point;
1207 for (i, round_proof) in round_proofs.iter().enumerate() {
1208 let lambda = challenger.sample_ext_element::<EF>();
1210 let expected_claim = numerator_eval * lambda + denominator_eval;
1212 assert_eq!(round_proof.sumcheck_proof.claimed_sum, expected_claim);
1213
1214 partially_verify_sumcheck_proof(
1216 &round_proof.sumcheck_proof,
1217 &mut challenger,
1218 i + num_interaction_variables as usize + 1,
1219 3,
1220 )
1221 .unwrap();
1222
1223 let (point, final_eval) = round_proof.sumcheck_proof.point_and_eval.clone();
1225 let eq_eval = Mle::full_lagrange_eval(&point, &eval_point);
1226 let numerator_sumcheck_eval = round_proof.numerator_0 * round_proof.denominator_1
1227 + round_proof.numerator_1 * round_proof.denominator_0;
1228 let denominator_sumcheck_eval = round_proof.denominator_0 * round_proof.denominator_1;
1229 let expected_final_eval =
1230 eq_eval * (numerator_sumcheck_eval * lambda + denominator_sumcheck_eval);
1231
1232 assert_eq!(final_eval, expected_final_eval, "failed at index {i}");
1233
1234 challenger.observe_ext_element(round_proof.numerator_0);
1236 challenger.observe_ext_element(round_proof.denominator_0);
1237 challenger.observe_ext_element(round_proof.numerator_1);
1238 challenger.observe_ext_element(round_proof.denominator_1);
1239
1240 eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
1242
1243 let last_coordinate = challenger.sample_ext_element::<EF>();
1245 eval_point.add_dimension_back(last_coordinate);
1246 numerator_eval = round_proof.numerator_0
1248 + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
1249 denominator_eval = round_proof.denominator_0
1250 + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
1251 }
1252 }
1253}