Skip to main content

sp1_hypercube/logup_gkr/
logup_poly.rs

1use std::sync::Arc;
2
3use rayon::prelude::*;
4use slop_algebra::{
5    interpolate_univariate_polynomial, ExtensionField, Field, UnivariatePolynomial,
6};
7use slop_multilinear::{Mle, Point};
8use slop_sumcheck::{ComponentPoly, SumcheckPoly, SumcheckPolyBase, SumcheckPolyFirstRound};
9
10use super::{InteractionLayer, LogUpGkrCpuLayer};
11
12/// Polynomial representing a round of the GKR circuit.
13pub struct LogupRoundPolynomial<F, EF> {
14    /// The values of the numerator and denominator polynomials
15    pub layer: PolynomialLayer<F, EF>,
16    /// The partial lagrange evaluation for the row variables
17    pub eq_row: Arc<Mle<EF>>,
18    /// The partial lagrange evaluation for the interaction variables
19    pub eq_interaction: Arc<Mle<EF>>,
20    /// The correction term for the eq polynomial.
21    pub eq_adjustment: EF,
22    /// The correction term for padding
23    pub padding_adjustment: EF,
24    /// The batching factor for the numerator and denominator claims.
25    pub lambda: EF,
26    /// The random point for the current GKR round.
27    pub point: Point<EF>,
28}
29
30/// A layer of the GKR circuit for the `LogupRoundPolynomial`.
31pub enum PolynomialLayer<F, EF> {
32    /// A layer of the GKR circuit.
33    CircuitLayer(LogUpGkrCpuLayer<F, EF>),
34    /// An interaction layer of the GKR circuit (`num_row_variables` == 1).
35    InteractionLayer(InteractionLayer<F, EF>),
36}
37
38impl<F: Field, EF: ExtensionField<F>> SumcheckPolyBase for LogupRoundPolynomial<F, EF> {
39    fn num_variables(&self) -> u32 {
40        self.eq_row.num_variables() + self.eq_interaction.num_variables()
41    }
42}
43
44impl<K: Field> ComponentPoly<K> for LogupRoundPolynomial<K, K> {
45    fn get_component_poly_evals(&self) -> Vec<K> {
46        match &self.layer {
47            PolynomialLayer::InteractionLayer(layer) => {
48                assert!(layer.numerator_0.guts().as_slice().len() == 1);
49                let numerator_0 = layer.numerator_0.guts().as_slice()[0];
50                let denominator_0 = layer.denominator_0.guts().as_slice()[0];
51                let numerator_1 = layer.numerator_1.guts().as_slice()[0];
52                let denominator_1 = layer.denominator_1.guts().as_slice()[0];
53                vec![numerator_0, denominator_0, numerator_1, denominator_1]
54            }
55            PolynomialLayer::CircuitLayer(_) => unreachable!(),
56        }
57    }
58}
59
60impl<K: Field> SumcheckPoly<K> for LogupRoundPolynomial<K, K> {
61    fn fix_last_variable(self, alpha: K) -> Self {
62        self.fix_t_variables(alpha, 1)
63    }
64
65    fn sum_as_poly_in_last_variable(&self, claim: Option<K>) -> UnivariatePolynomial<K> {
66        self.sum_as_poly_in_last_t_variables(claim, 1)
67    }
68}
69
70impl<K: ExtensionField<F>, F: Field> SumcheckPolyFirstRound<K> for LogupRoundPolynomial<F, K> {
71    type NextRoundPoly = LogupRoundPolynomial<K, K>;
72    #[allow(clippy::too_many_lines)]
73    fn fix_t_variables(mut self, alpha: K, t: usize) -> Self::NextRoundPoly {
74        assert!(t == 1);
75        // Remove the last coordinate from the point
76        let last_coordinate = self.point.remove_last_coordinate();
77        let padding_adjustment = self.padding_adjustment
78            * (last_coordinate * alpha + (K::one() - last_coordinate) * (K::one() - alpha));
79        match self.layer {
80            PolynomialLayer::InteractionLayer(layer) => {
81                let numerator_0 =
82                    Arc::new(layer.numerator_0.as_ref().fix_last_variable::<K>(alpha));
83                let denominator_0 =
84                    Arc::new(layer.denominator_0.as_ref().fix_last_variable::<K>(alpha));
85                let numerator_1 =
86                    Arc::new(layer.numerator_1.as_ref().fix_last_variable::<K>(alpha));
87                let denominator_1 =
88                    Arc::new(layer.denominator_1.as_ref().fix_last_variable::<K>(alpha));
89
90                let new_layer =
91                    InteractionLayer { numerator_0, denominator_0, numerator_1, denominator_1 };
92
93                let eq_interaction =
94                    Arc::new(self.eq_interaction.as_ref().fix_last_variable(alpha));
95
96                LogupRoundPolynomial {
97                    layer: PolynomialLayer::InteractionLayer(new_layer),
98                    eq_row: self.eq_row,
99                    eq_interaction,
100                    eq_adjustment: self.eq_adjustment,
101                    padding_adjustment,
102                    lambda: self.lambda,
103                    point: self.point,
104                }
105            }
106            PolynomialLayer::CircuitLayer(layer) => {
107                if layer.num_row_variables == 1 {
108                    let numerator_0: Vec<_> = layer
109                        .numerator_0
110                        .into_iter()
111                        .map(|mle| mle.fix_last_variable(alpha))
112                        .collect();
113                    let denominator_0: Vec<_> = layer
114                        .denominator_0
115                        .into_iter()
116                        .map(|mle| mle.fix_last_variable(alpha))
117                        .collect();
118                    let numerator_1: Vec<_> = layer
119                        .numerator_1
120                        .into_iter()
121                        .map(|mle| mle.fix_last_variable(alpha))
122                        .collect();
123                    let denominator_1: Vec<_> = layer
124                        .denominator_1
125                        .into_iter()
126                        .map(|mle| mle.fix_last_variable(alpha))
127                        .collect();
128
129                    let mut numerator_0_interactions: Vec<_> = numerator_0
130                        .into_iter()
131                        .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
132                        .collect();
133                    numerator_0_interactions
134                        .resize(1 << layer.num_interaction_variables, K::zero());
135
136                    let mut numerator_1_interactions: Vec<_> = numerator_1
137                        .into_iter()
138                        .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
139                        .collect();
140                    numerator_1_interactions
141                        .resize(1 << layer.num_interaction_variables, K::zero());
142
143                    let mut denominator_0_interactions: Vec<_> = denominator_0
144                        .into_iter()
145                        .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
146                        .collect();
147                    denominator_0_interactions
148                        .resize(1 << layer.num_interaction_variables, K::one());
149
150                    let mut denominator_1_interactions: Vec<_> = denominator_1
151                        .into_iter()
152                        .flat_map(|mle| mle.eval_at::<K>(&Point::from(vec![])).to_vec())
153                        .collect();
154                    denominator_1_interactions
155                        .resize(1 << layer.num_interaction_variables, K::one());
156
157                    let numerator_0_mle = Arc::new(Mle::from(numerator_0_interactions));
158                    let denominator_0_mle = Arc::new(Mle::from(denominator_0_interactions));
159                    let numerator_1_mle = Arc::new(Mle::from(numerator_1_interactions));
160                    let denominator_1_mle = Arc::new(Mle::from(denominator_1_interactions));
161
162                    let new_layer = InteractionLayer {
163                        numerator_0: numerator_0_mle,
164                        denominator_0: denominator_0_mle,
165                        numerator_1: numerator_1_mle,
166                        denominator_1: denominator_1_mle,
167                    };
168
169                    let eq_row = Arc::new(self.eq_row.as_ref().fix_last_variable(alpha));
170
171                    LogupRoundPolynomial {
172                        layer: PolynomialLayer::InteractionLayer(new_layer),
173                        eq_row,
174                        eq_interaction: self.eq_interaction,
175                        eq_adjustment: padding_adjustment,
176                        padding_adjustment: K::one(),
177                        lambda: self.lambda,
178                        point: self.point,
179                    }
180                } else {
181                    let numerator_0: Vec<_> = layer
182                        .numerator_0
183                        .into_iter()
184                        .map(|mle| mle.fix_last_variable(alpha))
185                        .collect();
186
187                    let denominator_0: Vec<_> = layer
188                        .denominator_0
189                        .into_iter()
190                        .map(|mle| mle.fix_last_variable(alpha))
191                        .collect();
192
193                    let numerator_1: Vec<_> = layer
194                        .numerator_1
195                        .into_iter()
196                        .map(|mle| mle.fix_last_variable(alpha))
197                        .collect();
198
199                    let denominator_1: Vec<_> = layer
200                        .denominator_1
201                        .into_iter()
202                        .map(|mle| mle.fix_last_variable(alpha))
203                        .collect();
204
205                    let eq_row = Arc::new(self.eq_row.as_ref().fix_last_variable(alpha));
206
207                    let new_layer = LogUpGkrCpuLayer {
208                        numerator_0,
209                        denominator_0,
210                        numerator_1,
211                        denominator_1,
212                        num_row_variables: layer.num_row_variables - 1,
213                        num_interaction_variables: layer.num_interaction_variables,
214                    };
215
216                    LogupRoundPolynomial {
217                        layer: PolynomialLayer::CircuitLayer(new_layer),
218                        eq_row,
219                        eq_interaction: self.eq_interaction,
220                        eq_adjustment: self.eq_adjustment,
221                        padding_adjustment,
222                        lambda: self.lambda,
223                        point: self.point,
224                    }
225                }
226            }
227        }
228    }
229
230    #[allow(clippy::too_many_lines)]
231    fn sum_as_poly_in_last_t_variables(
232        &self,
233        claim: Option<K>,
234        t: usize,
235    ) -> UnivariatePolynomial<K> {
236        assert!(t == 1);
237        let claim = claim.unwrap();
238
239        let (mut eval_zero, mut eval_half, eq_sum) = match &self.layer {
240            PolynomialLayer::InteractionLayer(layer) => {
241                let numerator_0 = layer.numerator_0.clone();
242                let numerator_1 = layer.numerator_1.clone();
243                let denominator_0 = layer.denominator_0.clone();
244                let denominator_1 = layer.denominator_1.clone();
245                let eq_interaction = self.eq_interaction.clone();
246                let lambda = self.lambda;
247                let numerator_eval_0 = numerator_0
248                    .guts()
249                    .as_slice()
250                    .par_iter()
251                    .step_by(2)
252                    .zip_eq(numerator_1.guts().as_slice().par_iter().step_by(2))
253                    .zip_eq(denominator_0.guts().as_slice().par_iter().step_by(2))
254                    .zip_eq(denominator_1.guts().as_slice().par_iter().step_by(2))
255                    .zip_eq(eq_interaction.guts().as_slice().par_iter().step_by(2))
256                    .map(|((((n0, n1), d0), d1), e)| *e * (*d0 * *n1 + *d1 * *n0))
257                    .sum::<K>();
258
259                let numerator_eval_half = numerator_0
260                    .guts()
261                    .as_slice()
262                    .par_chunks(2)
263                    .zip_eq(numerator_1.guts().as_slice().par_chunks(2))
264                    .zip_eq(denominator_0.guts().as_slice().par_chunks(2))
265                    .zip_eq(denominator_1.guts().as_slice().par_chunks(2))
266                    .zip_eq(eq_interaction.guts().as_slice().par_chunks(2))
267                    .map(|((((n0_chunk, n1_chunk), d0_chunk), d1_chunk), e_chunk)| {
268                        let n0_half = n0_chunk[0] + n0_chunk[1];
269                        let n1_half = n1_chunk[0] + n1_chunk[1];
270                        let d0_half = d0_chunk[0] + d0_chunk[1];
271                        let d1_half = d1_chunk[0] + d1_chunk[1];
272                        let e_half = e_chunk[0] + e_chunk[1];
273                        e_half * (d0_half * n1_half + d1_half * n0_half)
274                    })
275                    .sum::<K>();
276
277                let denominator_eval_0 = denominator_0
278                    .guts()
279                    .as_slice()
280                    .par_iter()
281                    .step_by(2)
282                    .zip_eq(denominator_1.guts().as_slice().par_iter().step_by(2))
283                    .zip_eq(eq_interaction.guts().as_slice().par_iter().step_by(2))
284                    .map(|((d0, d1), e)| *e * (*d0 * *d1))
285                    .sum::<K>();
286
287                let denominator_eval_half = denominator_0
288                    .guts()
289                    .as_slice()
290                    .par_chunks(2)
291                    .zip_eq(denominator_1.guts().as_slice().par_chunks(2))
292                    .zip_eq(eq_interaction.guts().as_slice().par_chunks(2))
293                    .map(|((d0_chunk, d1_chunk), e_chunk)| {
294                        let d0_half = d0_chunk[0] + d0_chunk[1];
295                        let d1_half = d1_chunk[0] + d1_chunk[1];
296                        let e_half = e_chunk[0] + e_chunk[1];
297                        e_half * (d0_half * d1_half)
298                    })
299                    .sum::<K>();
300
301                let eq_half_sum = eq_interaction
302                    .guts()
303                    .as_slice()
304                    .par_chunks(2)
305                    .map(|e_chunk| e_chunk[0] + e_chunk[1])
306                    .sum::<K>();
307
308                (
309                    lambda * numerator_eval_0 + denominator_eval_0,
310                    lambda * numerator_eval_half + denominator_eval_half,
311                    eq_half_sum,
312                )
313            }
314            PolynomialLayer::CircuitLayer(layer) => {
315                let numerator_0 = layer.numerator_0.clone();
316                let numerator_1 = layer.numerator_1.clone();
317                let denominator_0 = layer.denominator_0.clone();
318                let denominator_1 = layer.denominator_1.clone();
319                let eq_row = self.eq_row.clone();
320                // println!("eq_row.num_non_zero_entries(): {:?}", eq_row.num_non_zero_entries());
321                assert!(eq_row.num_non_zero_entries().is_multiple_of(2));
322                let eq_interaction = self.eq_interaction.clone();
323                let lambda = self.lambda;
324
325                let mut interaction_offset = 0;
326                let mut eval_0 = K::zero();
327                let mut eval_half = K::zero();
328                let mut eq_sum = K::zero();
329                for (numerator_0, numerator_1, denominator_0, denominator_1) in
330                    itertools::izip!(numerator_0, numerator_1, denominator_0, denominator_1)
331                {
332                    if let Some(inner) = numerator_0.inner() {
333                        assert!(numerator_0.num_variables() > 0);
334                        let numerator_1_inner = numerator_1.inner().as_ref().unwrap();
335                        // println!(
336                        //     "numerator_1_inner.num_variables(): {:?}",
337                        //     numerator_1_inner.num_variables()
338                        // );
339                        let denominator_0_inner = denominator_0.inner().as_ref().unwrap();
340                        let denominator_1_inner = denominator_1.inner().as_ref().unwrap();
341                        let (eval_0_chip, eval_half_chip, eq_sum_chip) =
342                            inner
343                                .guts()
344                                .as_slice()
345                                .par_chunks(2 * numerator_0.num_polynomials())
346                                .zip_eq(
347                                    numerator_1_inner
348                                        .guts()
349                                        .as_slice()
350                                        .par_chunks(2 * numerator_1_inner.num_polynomials()),
351                                )
352                                .zip_eq(
353                                    denominator_0_inner
354                                        .guts()
355                                        .as_slice()
356                                        .par_chunks(2 * denominator_0_inner.num_polynomials()),
357                                )
358                                .zip_eq(
359                                    denominator_1_inner
360                                        .guts()
361                                        .as_slice()
362                                        .par_chunks(2 * denominator_1_inner.num_polynomials()),
363                                )
364                                .zip(eq_row.guts().as_slice().par_chunks(2))
365                                .map(
366                                    |(
367                                        (((numer_0_row, numer_1_row), denom_0_row), denom_1_row),
368                                        eq_row_chunk,
369                                    )| {
370                                        let eq_interactions_chip = eq_interaction.guts().as_slice()
371                                            [interaction_offset
372                                                ..interaction_offset
373                                                    + numerator_0.num_polynomials()]
374                                            .par_iter();
375
376                                        let (numer_0_row_0, numer_0_row_1) =
377                                            numer_0_row.split_at(numerator_0.num_polynomials());
378                                        let (denom_0_row_0, denom_0_row_1) =
379                                            denom_0_row.split_at(denominator_0.num_polynomials());
380                                        let (denom_1_row_0, denom_1_row_1) =
381                                            denom_1_row.split_at(denominator_1.num_polynomials());
382                                        let (numer_1_row_0, numer_1_row_1) =
383                                            numer_1_row.split_at(numerator_1.num_polynomials());
384                                        let eq_row_0 = eq_row_chunk[0];
385                                        let eq_row_1 = eq_row_chunk[1];
386                                        if numer_0_row.len() == 2 * numerator_0.num_polynomials() {
387                                            let numerator_0_eval = numer_0_row_0
388                                                .par_iter()
389                                                .zip_eq(numer_1_row_0.par_iter())
390                                                .zip_eq(denom_0_row_0.par_iter())
391                                                .zip_eq(denom_1_row_0.par_iter())
392                                                .zip_eq(eq_interactions_chip.clone())
393                                                .map(|((((n0, n1), d0), d1), e)| {
394                                                    // assert_eq!(*e, K::one());
395                                                    *e * (*d0 * *n1 + *d1 * *n0)
396                                                })
397                                                .sum::<K>();
398                                            let denominator_0_eval = denom_0_row_0
399                                                .par_iter()
400                                                .zip_eq(denom_1_row_0.par_iter())
401                                                .zip_eq(eq_interactions_chip.clone())
402                                                .map(|((d0, d1), e)| *e * (*d0 * *d1))
403                                                .sum::<K>();
404                                            let numerator_half_eval = numer_0_row_0
405                                            .par_iter()
406                                            .zip_eq(numer_1_row_0.par_iter())
407                                            .zip_eq(denom_0_row_0.par_iter())
408                                            .zip_eq(denom_1_row_0.par_iter())
409                                            .zip_eq(numer_0_row_1.par_iter())
410                                            .zip_eq(numer_1_row_1.par_iter())
411                                            .zip_eq(denom_0_row_1.par_iter())
412                                            .zip_eq(denom_1_row_1.par_iter())
413                                            .zip_eq(eq_interactions_chip.clone())
414                                            .map(
415                                                |(((
416                                                    (
417                                                        (
418                                                            (((n0_0, n1_0), d0_0), d1_0),
419                                                            n0_1,
420                                                        ),
421                                                        n1_1,
422                                                    ),
423                                                    d0_1), d1_1),
424                                                    e,
425                                                )| {
426                                                    *e * ((*d0_0 + *d0_1) * (*n1_0 + *n1_1)
427                                                        + (*d1_0 + *d1_1) * (*n0_0 + *n0_1))
428                                                },
429                                            )
430                                            .sum::<K>();
431                                            let denominator_half_eval = denom_0_row_0
432                                                .par_iter()
433                                                .zip_eq(denom_1_row_0.par_iter())
434                                                .zip_eq(denom_0_row_1.par_iter())
435                                                .zip_eq(denom_1_row_1.par_iter())
436                                                .zip_eq(eq_interactions_chip.clone())
437                                                .map(|((((d0_0, d1_0), d0_1), d1_1), e)| {
438                                                    *e * ((*d0_0 + *d0_1) * (*d1_0 + *d1_1))
439                                                })
440                                                .sum::<K>();
441                                            let eq_interactions_chip_half = eq_interactions_chip
442                                                .map(|e| *e * (eq_row_0 + eq_row_1))
443                                                .sum::<K>();
444                                            (
445                                                (lambda * numerator_0_eval + denominator_0_eval)
446                                                    * eq_row_0,
447                                                (lambda * numerator_half_eval
448                                                    + denominator_half_eval)
449                                                    * (eq_row_0 + eq_row_1),
450                                                eq_interactions_chip_half,
451                                            )
452                                        } else {
453                                            let numerator_0_eval = numer_0_row_0
454                                                .par_iter()
455                                                .zip_eq(numer_1_row_0.par_iter())
456                                                .zip_eq(denom_0_row_0.par_iter())
457                                                .zip_eq(denom_1_row_0.par_iter())
458                                                .zip_eq(eq_interactions_chip.clone())
459                                                .map(|((((n0, n1), d0), d1), e)| {
460                                                    *e * (*d0 * *n1 + *d1 * *n0)
461                                                })
462                                                .sum::<K>();
463                                            let denominator_0_eval = denom_0_row_0
464                                                .par_iter()
465                                                .zip_eq(denom_1_row_0.par_iter())
466                                                .zip_eq(eq_interactions_chip.clone())
467                                                .map(|((d0, d1), e)| *e * (*d0 * *d1))
468                                                .sum::<K>();
469                                            let numerator_half_eval = numer_0_row_0
470                                                .par_iter()
471                                                .zip_eq(numer_1_row_0.par_iter())
472                                                .zip_eq(denom_0_row_0.par_iter())
473                                                .zip_eq(denom_1_row_0.par_iter())
474                                                .zip_eq(eq_interactions_chip.clone())
475                                                .map(|((((n0, n1), d0), d1), e)| {
476                                                    *e * ((*d0 + K::one()) * *n1
477                                                        + (*d1 + K::one()) * *n0)
478                                                })
479                                                .sum::<K>();
480                                            let denominator_half_eval = denom_0_row_0
481                                                .par_iter()
482                                                .zip_eq(denom_1_row_0.par_iter())
483                                                .zip_eq(eq_interactions_chip.clone())
484                                                .map(|((d0, d1), e)| {
485                                                    *e * ((*d0 + K::one()) * (*d1 + K::one()))
486                                                })
487                                                .sum::<K>();
488                                            let eq_interactions_chip_half = eq_interactions_chip
489                                                .map(|e| *e * (eq_row_0 + eq_row_1))
490                                                .sum::<K>();
491                                            (
492                                                (lambda * numerator_0_eval + denominator_0_eval)
493                                                    * eq_row_0,
494                                                (lambda * numerator_half_eval
495                                                    + denominator_half_eval)
496                                                    * (eq_row_0 + eq_row_1),
497                                                eq_interactions_chip_half,
498                                            )
499                                        }
500                                    },
501                                )
502                                .reduce(
503                                    || (K::zero(), K::zero(), K::zero()),
504                                    |(y_0_acc, y_half_acc, eq_sum_acc), (y_0, y_half, eq_sum)| {
505                                        (y_0_acc + y_0, y_half_acc + y_half, eq_sum_acc + eq_sum)
506                                    },
507                                );
508                        eval_0 += eval_0_chip;
509                        eval_half += eval_half_chip;
510                        eq_sum += eq_sum_chip;
511                    }
512                    interaction_offset += numerator_0.num_polynomials();
513                    // println!("interaction_offset: {:?}", interaction_offset);
514                }
515
516                (eval_0, eval_half, eq_sum)
517            }
518        };
519
520        // Correct the evaluations by the sum of the eq polynomial, which accounts for the
521        // contribution of padded row for the denominator expression
522        // `\Sum_i eq * denominator_0 * denominator_1`.
523        let eq_correction_term = self.padding_adjustment - eq_sum;
524        // println!("eq_correction_term: {:?}", eq_correction_term);
525        // The evaluation at zero just gets the eq correction term.
526        eval_zero += eq_correction_term * (K::one() - *self.point.last().unwrap());
527        // The evaluation at 1/2 gets the eq correction term times 4, since the denominators
528        // have a 1/2 in them for the rest of the evaluations (so we multiply by 2 twice).
529        eval_half += eq_correction_term * K::from_canonical_u16(4);
530
531        // Since the sumcheck polynomial is homogeneous of degree 3, we need to divide by
532        // 8 = 2^3 to account for the evaluations at 1/2 to be double their true value.
533        let eval_half = eval_half * K::from_canonical_u16(8).inverse();
534
535        let eval_zero = eval_zero * self.eq_adjustment;
536        let eval_half = eval_half * self.eq_adjustment;
537
538        // Get the root of the eq polynomial which gives an evaluation of zero.
539        let point_last = self.point.last().unwrap();
540        let b_const = (K::one() - *point_last) / (K::one() - point_last.double());
541
542        let eval_one = claim - eval_zero;
543
544        interpolate_univariate_polynomial(
545            &[
546                K::from_canonical_u16(0),
547                K::from_canonical_u16(1),
548                K::from_canonical_u16(2).inverse(),
549                b_const,
550            ],
551            &[eval_zero, eval_one, eval_half, K::zero()],
552        )
553    }
554}
555
556#[cfg(test)]
557mod tests {
558    use crate::{prove_gkr_round, GkrCircuitLayer, LogupGkrCpuTraceGenerator};
559
560    use super::*;
561    use itertools::Itertools;
562    use rand::{thread_rng, Rng};
563    use slop_algebra::{extension::BinomialExtensionField, AbstractField};
564    use slop_alloc::CpuBackend;
565
566    use slop_challenger::{FieldChallenger, IopCtx};
567    use slop_matrix::dense::RowMajorMatrix;
568    use slop_multilinear::{PaddedMle, Padding};
569    use slop_sumcheck::{partially_verify_sumcheck_proof, reduce_sumcheck_to_evaluation};
570    use slop_tensor::Tensor;
571    use sp1_primitives::SP1Field;
572
573    type EF = BinomialExtensionField<SP1Field, 4>;
574    type F = SP1Field;
575
576    fn random_layer(
577        rng: &mut impl Rng,
578        interaction_counts: &[usize],
579        num_rows: usize,
580        num_row_variables: usize,
581        num_interaction_variables: usize,
582    ) -> LogUpGkrCpuLayer<F, EF> {
583        let numerator_0 = interaction_counts
584            .iter()
585            .map(|count| {
586                let guts = Tensor::<F>::rand(rng, [num_rows, *count]);
587                Mle::new(guts)
588            })
589            .collect::<Vec<_>>();
590        let denominator_0 = interaction_counts
591            .iter()
592            .map(|count| {
593                let guts = Tensor::<EF>::rand(rng, [num_rows, *count]);
594                Mle::new(guts)
595            })
596            .collect::<Vec<_>>();
597        let numerator_1 = interaction_counts
598            .iter()
599            .map(|count| {
600                let guts = Tensor::<F>::rand(rng, [num_rows, *count]);
601                Mle::new(guts)
602            })
603            .collect::<Vec<_>>();
604        let denominator_1 = interaction_counts
605            .iter()
606            .map(|count| {
607                let guts = Tensor::<EF>::rand(rng, [num_rows, *count]);
608                Mle::new(guts)
609            })
610            .collect::<Vec<_>>();
611
612        let padded_numerator_0 = numerator_0
613            .iter()
614            .map(|mle| {
615                PaddedMle::padded_with_zeros(Arc::new(mle.clone()), num_row_variables as u32)
616            })
617            .collect::<Vec<_>>();
618
619        let padded_denominator_0 = denominator_0
620            .iter()
621            .map(|mle| {
622                let num_polys = mle.num_polynomials();
623                PaddedMle::padded(
624                    Arc::new(mle.clone()),
625                    num_row_variables as u32,
626                    Padding::Constant((EF::one(), num_polys, CpuBackend)),
627                )
628            })
629            .collect::<Vec<_>>();
630
631        let padded_numerator_1 = numerator_1
632            .iter()
633            .map(|mle| {
634                PaddedMle::padded_with_zeros(Arc::new(mle.clone()), num_row_variables as u32)
635            })
636            .collect::<Vec<_>>();
637        let padded_denominator_1 = denominator_1
638            .iter()
639            .map(|mle| {
640                let num_polys = mle.num_polynomials();
641                PaddedMle::padded(
642                    Arc::new(mle.clone()),
643                    num_row_variables as u32,
644                    Padding::Constant((EF::one(), num_polys, CpuBackend)),
645                )
646            })
647            .collect::<Vec<_>>();
648
649        LogUpGkrCpuLayer {
650            numerator_0: padded_numerator_0,
651            denominator_0: padded_denominator_0,
652            numerator_1: padded_numerator_1,
653            denominator_1: padded_denominator_1,
654            num_row_variables,
655            num_interaction_variables,
656        }
657    }
658
659    #[test]
660    #[allow(clippy::too_many_lines)]
661    fn test_logup_poly_fix_last_variable() {
662        let mut rng = thread_rng();
663        let interaction_counts = vec![1];
664        let num_rows: usize = 4;
665        let num_row_variables = 2;
666        let num_interaction_variables =
667            interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
668        let layer = random_layer(
669            &mut rng,
670            &interaction_counts,
671            num_rows,
672            num_row_variables as usize,
673            num_interaction_variables as usize,
674        );
675
676        let poly_point = Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
677        let (interaction_point, row_point) =
678            poly_point.split_at(num_interaction_variables as usize);
679
680        let random_point =
681            Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
682        let (interaction_random_point, row_random_point) =
683            random_point.split_at(num_interaction_variables as usize);
684
685        let lambda = rng.gen::<EF>();
686        let eq_row = Mle::partial_lagrange(&row_point);
687        let eq_interaction = Mle::partial_lagrange(&interaction_point);
688
689        let first_polynomial = LogupRoundPolynomial {
690            layer: PolynomialLayer::CircuitLayer(layer),
691            eq_row: Arc::new(eq_row),
692            eq_interaction: Arc::new(eq_interaction),
693            eq_adjustment: EF::one(),
694            padding_adjustment: EF::one(),
695            lambda,
696            point: poly_point,
697        };
698
699        let PolynomialLayer::CircuitLayer(layer) = &first_polynomial.layer else {
700            panic!("first polynomial is not a circuit layer");
701        };
702
703        let mut numerator_0_interactions: Vec<EF> = layer
704            .numerator_0
705            .iter()
706            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
707            .collect();
708        numerator_0_interactions.resize(1 << layer.num_interaction_variables, EF::zero());
709
710        let mut numerator_1_interactions: Vec<EF> = layer
711            .numerator_1
712            .iter()
713            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
714            .collect();
715        numerator_1_interactions.resize(1 << layer.num_interaction_variables, EF::zero());
716
717        let mut denominator_0_interactions: Vec<EF> = layer
718            .denominator_0
719            .iter()
720            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
721            .collect();
722        denominator_0_interactions.resize(1 << layer.num_interaction_variables, EF::one());
723
724        let mut denominator_1_interactions: Vec<EF> = layer
725            .denominator_1
726            .iter()
727            .flat_map(|mle| mle.eval_at::<EF>(&row_random_point).to_vec())
728            .collect();
729        denominator_1_interactions.resize(1 << layer.num_interaction_variables, EF::one());
730
731        // Fix last variable until we get to interaction layer
732        let mut round_polynomial =
733            first_polynomial.fix_t_variables(*row_random_point.last().unwrap(), 1);
734
735        for alpha in row_random_point.iter().rev().skip(1) {
736            round_polynomial = round_polynomial.fix_t_variables(*alpha, 1);
737        }
738
739        let PolynomialLayer::InteractionLayer(interaction_layer) = &round_polynomial.layer else {
740            panic!("round polynomial is not an interaction layer");
741        };
742
743        // Check expected mle against actual mle for first interaction layer
744        for (i, numerator_0_interaction) in numerator_0_interactions.iter().enumerate() {
745            assert_eq!(
746                *numerator_0_interaction,
747                interaction_layer.numerator_0.guts().as_slice()[i]
748            );
749        }
750        for (i, numerator_1_interaction) in numerator_1_interactions.iter().enumerate() {
751            assert_eq!(
752                *numerator_1_interaction,
753                interaction_layer.numerator_1.guts().as_slice()[i]
754            );
755        }
756        for (i, denominator_0_interaction) in denominator_0_interactions.iter().enumerate() {
757            assert_eq!(
758                *denominator_0_interaction,
759                interaction_layer.denominator_0.guts().as_slice()[i]
760            );
761        }
762        for (i, denominator_1_interaction) in denominator_1_interactions.iter().enumerate() {
763            assert_eq!(
764                *denominator_1_interaction,
765                interaction_layer.denominator_1.guts().as_slice()[i]
766            );
767        }
768
769        // Get the expected evaluations
770        let numerator_0_eval = interaction_layer.numerator_0.eval_at(&interaction_random_point)[0];
771        let numerator_1_eval = interaction_layer.numerator_1.eval_at(&interaction_random_point)[0];
772        let denominator_0_eval =
773            interaction_layer.denominator_0.eval_at(&interaction_random_point)[0];
774        let denominator_1_eval =
775            interaction_layer.denominator_1.eval_at(&interaction_random_point)[0];
776
777        // Proceed with rest of interaction layers.
778        for alpha in interaction_random_point.iter().rev() {
779            round_polynomial = round_polynomial.fix_t_variables(*alpha, 1);
780        }
781
782        let [n0, d0, n1, d1] = round_polynomial.get_component_poly_evals().try_into().unwrap();
783
784        assert_eq!(numerator_0_eval, n0);
785        assert_eq!(numerator_1_eval, n1);
786        assert_eq!(denominator_0_eval, d0);
787        assert_eq!(denominator_1_eval, d1);
788    }
789
790    #[test]
791    #[allow(clippy::too_many_lines)]
792    fn test_logup_poly_sumcheck_circuit_layer() {
793        type GC = sp1_primitives::SP1GlobalContext;
794        let mut rng = thread_rng();
795
796        let get_challenger = move || GC::default_challenger();
797
798        let interaction_counts = vec![4, 5, 6];
799        let num_rows: usize = 8;
800        let num_row_variables = 4;
801
802        let num_interaction_variables =
803            interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
804        let layer = random_layer(
805            &mut rng,
806            &interaction_counts,
807            num_rows,
808            num_row_variables as usize,
809            num_interaction_variables as usize,
810        );
811
812        let poly_point = Point::<EF>::rand(&mut rng, num_row_variables + num_interaction_variables);
813        let (interaction_point, row_point) =
814            poly_point.split_at(num_interaction_variables as usize);
815
816        let eq_row = Mle::partial_lagrange(&row_point);
817        let eq_interaction = Mle::partial_lagrange(&interaction_point);
818
819        let numerator_0 = layer.numerator_0.clone();
820        let numerator_1 = layer.numerator_1.clone();
821        let denominator_0 = layer.denominator_0.clone();
822        let denominator_1 = layer.denominator_1.clone();
823        let lambda = rng.gen::<EF>();
824
825        let round_polynomial = LogupRoundPolynomial {
826            layer: PolynomialLayer::CircuitLayer(layer),
827            eq_row: Arc::new(eq_row),
828            eq_interaction: Arc::new(eq_interaction),
829            eq_adjustment: EF::one(),
830            padding_adjustment: EF::one(),
831            lambda,
832            point: poly_point.clone(),
833        };
834
835        let total_eq = Mle::partial_lagrange(&poly_point);
836
837        let total_eq_guts = total_eq.guts().as_slice().to_vec().clone();
838
839        let claim = {
840            let mut offset = 0;
841            let real_claim = numerator_0
842                .iter()
843                .zip_eq(numerator_1.iter())
844                .zip_eq(denominator_0.iter())
845                .zip_eq(denominator_1.iter())
846                .map(|(((n_0, n_1), d_0), d_1)| {
847                    // Add padded rows to n0 so that num_rows is next power of 2
848                    let num_padding = vec![
849                        F::zero();
850                        ((1 << num_row_variables) - num_rows)
851                            * n_0.num_polynomials()
852                    ];
853                    let den_padding = vec![
854                        EF::one();
855                        ((1 << num_row_variables) - num_rows)
856                            * d_0.num_polynomials()
857                    ];
858
859                    let padded_n0 = n_0
860                        .inner()
861                        .as_ref()
862                        .unwrap()
863                        .guts()
864                        .as_slice()
865                        .iter()
866                        .copied()
867                        .chain(num_padding.iter().copied())
868                        .collect::<Vec<_>>();
869                    let padded_n1 = n_1
870                        .inner()
871                        .as_ref()
872                        .unwrap()
873                        .guts()
874                        .as_slice()
875                        .iter()
876                        .copied()
877                        .chain(num_padding.iter().copied())
878                        .collect::<Vec<_>>();
879                    let padded_d0 = d_0
880                        .inner()
881                        .as_ref()
882                        .unwrap()
883                        .guts()
884                        .as_slice()
885                        .iter()
886                        .copied()
887                        .chain(den_padding.iter().copied())
888                        .collect::<Vec<_>>();
889                    let padded_d1 = d_1
890                        .inner()
891                        .as_ref()
892                        .unwrap()
893                        .guts()
894                        .as_slice()
895                        .iter()
896                        .copied()
897                        .chain(den_padding.iter().copied())
898                        .collect::<Vec<_>>();
899                    let padded_d0 =
900                        Mle::from(RowMajorMatrix::new(padded_d0, d_0.num_polynomials()));
901                    let padded_d1 =
902                        Mle::from(RowMajorMatrix::new(padded_d1, d_1.num_polynomials()));
903                    let padded_n0 =
904                        Mle::from(RowMajorMatrix::new(padded_n0, n_0.num_polynomials()));
905                    let padded_n1 =
906                        Mle::from(RowMajorMatrix::new(padded_n1, n_1.num_polynomials()));
907
908                    let result = padded_n0
909                        .guts()
910                        .transpose()
911                        .as_slice()
912                        .iter()
913                        .zip_eq(padded_n1.guts().transpose().as_slice().iter())
914                        .zip_eq(padded_d0.guts().transpose().as_slice().iter())
915                        .zip_eq(padded_d1.guts().transpose().as_slice().iter())
916                        .zip(total_eq_guts.iter().skip(offset))
917                        .map(|((((n_0, n_1), d_0), d_1), e)| {
918                            let numerator_eval = *d_1 * *n_0 + *d_0 * *n_1;
919                            let denominator_eval = *d_0 * *d_1;
920                            *e * (numerator_eval * lambda + denominator_eval)
921                        })
922                        .sum::<EF>();
923
924                    offset += padded_n0.guts().as_slice().len();
925                    result
926                })
927                .sum::<EF>();
928            let remaining_eq = total_eq_guts.iter().copied().skip(offset).sum::<EF>();
929            real_claim + remaining_eq
930        };
931
932        let mut challenger = get_challenger();
933        let (proof, evals) = reduce_sumcheck_to_evaluation(
934            vec![round_polynomial],
935            &mut challenger,
936            vec![claim],
937            1,
938            EF::one(),
939        );
940
941        let mut challenger = get_challenger();
942        partially_verify_sumcheck_proof(
943            &proof,
944            &mut challenger,
945            (num_row_variables + num_interaction_variables) as usize,
946            3,
947        )
948        .unwrap();
949
950        let (point, expected_final_eval) = proof.point_and_eval;
951
952        // Assert that the point has the expected dimension.
953        assert_eq!(point.dimension() as u32, num_row_variables + num_interaction_variables);
954
955        // Calculate the expected evaluations at the point.
956        let [evals] = evals.try_into().unwrap();
957        assert_eq!(evals.len(), 4);
958        let [n_0, d_0, n_1, d_1] = evals.try_into().unwrap();
959
960        let eq_eval = Mle::full_lagrange_eval(&poly_point, &point);
961
962        let expected_numerator_eval = n_0 * d_1 + n_1 * d_0;
963        let expected_denominator_eval = d_0 * d_1;
964        let eval = expected_numerator_eval * lambda + expected_denominator_eval;
965        let final_eval = eq_eval * eval;
966
967        // Assert that the final eval is correct.
968        assert_eq!(final_eval, expected_final_eval);
969    }
970
971    #[test]
972    #[allow(clippy::too_many_lines)]
973    fn test_logup_gkr_circuit_transition() {
974        type TraceGenerator = LogupGkrCpuTraceGenerator<SP1Field, EF, ()>;
975        let mut rng = thread_rng();
976
977        let trace_generator = TraceGenerator::default();
978
979        let interaction_counts = vec![4, 5, 6];
980        let num_rows: usize = 8;
981        let num_row_variables = 4;
982        let num_interaction_variables =
983            interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
984        let layer = random_layer(
985            &mut rng,
986            &interaction_counts,
987            num_rows,
988            num_row_variables as usize,
989            num_interaction_variables as usize,
990        );
991        let next_layer = trace_generator.layer_transition(&layer);
992
993        let curr_numerator_0 = layer.numerator_0;
994        let curr_numerator_1 = layer.numerator_1;
995        let curr_denominator_0 = layer.denominator_0;
996        let curr_denominator_1 = layer.denominator_1;
997
998        let next_numerator_0 = next_layer.numerator_0;
999        let next_numerator_1 = next_layer.numerator_1;
1000        let next_denominator_0 = next_layer.denominator_0;
1001        let next_denominator_1 = next_layer.denominator_1;
1002
1003        for (next_n0, next_n1, next_d0, next_d1, curr_n0, curr_n1, curr_d0, curr_d1) in itertools::izip!(
1004            next_numerator_0.iter(),
1005            next_numerator_1.iter(),
1006            next_denominator_0.iter(),
1007            next_denominator_1.iter(),
1008            curr_numerator_0.iter(),
1009            curr_numerator_1.iter(),
1010            curr_denominator_0.iter(),
1011            curr_denominator_1.iter()
1012        ) {
1013            let next_n1_inner = next_n1.inner().as_ref().unwrap();
1014            let next_n0_inner = next_n0.inner().as_ref().unwrap();
1015            let next_d0_inner = next_d0.inner().as_ref().unwrap();
1016            let next_d1_inner = next_d1.inner().as_ref().unwrap();
1017            let curr_n0_inner = curr_n0.inner().as_ref().unwrap();
1018            let curr_n1_inner = curr_n1.inner().as_ref().unwrap();
1019            let curr_d0_inner = curr_d0.inner().as_ref().unwrap();
1020            let curr_d1_inner = curr_d1.inner().as_ref().unwrap();
1021            let _ = next_n0_inner
1022                .guts()
1023                .transpose()
1024                .as_slice()
1025                .chunks(next_n0.num_real_entries())
1026                .zip_eq(
1027                    next_n1_inner.guts().transpose().as_slice().chunks(next_n1.num_real_entries()),
1028                )
1029                .zip_eq(
1030                    curr_n0_inner
1031                        .guts()
1032                        .transpose()
1033                        .as_slice()
1034                        .chunks(curr_n0.num_real_entries())
1035                        .zip_eq(
1036                            curr_n1_inner
1037                                .guts()
1038                                .transpose()
1039                                .as_slice()
1040                                .chunks(curr_n1.num_real_entries()),
1041                        ),
1042                )
1043                .zip_eq(
1044                    curr_d0_inner
1045                        .guts()
1046                        .transpose()
1047                        .as_slice()
1048                        .chunks(curr_d0.num_real_entries())
1049                        .zip_eq(
1050                            curr_d1_inner
1051                                .guts()
1052                                .transpose()
1053                                .as_slice()
1054                                .chunks(curr_d1.num_real_entries()),
1055                        ),
1056                )
1057                .map(
1058                    |(
1059                        ((n0_col, n1_col), (curr_n0_col, curr_n1_col)),
1060                        (curr_d0_col, curr_d1_col),
1061                    )| {
1062                        let next_n = n0_col.iter().interleave(n1_col.iter()).collect::<Vec<_>>();
1063                        for (
1064                            i,
1065                            ((((next_n_val, curr_n0_val), curr_n1_val), curr_d0_val), curr_d1_val),
1066                        ) in next_n
1067                            .iter()
1068                            .copied()
1069                            .zip_eq(curr_n0_col.iter())
1070                            .zip_eq(curr_n1_col.iter())
1071                            .zip_eq(curr_d0_col.iter())
1072                            .zip_eq(curr_d1_col.iter())
1073                            .enumerate()
1074                        {
1075                            assert_eq!(
1076                                *next_n_val,
1077                                *curr_d1_val * *curr_n0_val + *curr_d0_val * *curr_n1_val,
1078                                "failed at index {i}"
1079                            );
1080                        }
1081                    },
1082                );
1083            let _ = next_d0_inner
1084                .guts()
1085                .transpose()
1086                .as_slice()
1087                .chunks(next_d0.num_real_entries())
1088                .zip_eq(
1089                    next_d1_inner.guts().transpose().as_slice().chunks(next_d1.num_real_entries()),
1090                )
1091                .zip_eq(
1092                    curr_d0_inner
1093                        .guts()
1094                        .transpose()
1095                        .as_slice()
1096                        .chunks(curr_d0.num_real_entries())
1097                        .zip_eq(
1098                            curr_d1_inner
1099                                .guts()
1100                                .transpose()
1101                                .as_slice()
1102                                .chunks(curr_d1.num_real_entries()),
1103                        ),
1104                )
1105                .map(|((next_d0_col, next_d1_col), (curr_d0_col, curr_d1_col))| {
1106                    let next_d =
1107                        next_d0_col.iter().interleave(next_d1_col.iter()).collect::<Vec<_>>();
1108                    for (i, ((next_d_val, curr_d0_val), curr_d1_val)) in next_d
1109                        .iter()
1110                        .copied()
1111                        .zip_eq(curr_d0_col.iter())
1112                        .zip_eq(curr_d1_col.iter())
1113                        .enumerate()
1114                    {
1115                        assert_eq!(*next_d_val, *curr_d0_val * *curr_d1_val, "failed at index {i}");
1116                    }
1117                });
1118        }
1119    }
1120
1121    #[test]
1122    fn test_logup_gkr_round_prover() {
1123        type GC = sp1_primitives::SP1GlobalContext;
1124        type TraceGenerator = LogupGkrCpuTraceGenerator<SP1Field, EF, ()>;
1125        let get_challenger = move || GC::default_challenger();
1126        let trace_generator = TraceGenerator::default();
1127
1128        let mut rng = thread_rng();
1129
1130        let interaction_counts = vec![4, 5, 6];
1131        let num_interaction_variables =
1132            interaction_counts.iter().sum::<usize>().next_power_of_two().ilog2();
1133        let num_rows: usize = 32;
1134        let num_row_variables = 7;
1135        let input_layer = random_layer(
1136            &mut rng,
1137            &interaction_counts,
1138            num_rows,
1139            num_row_variables as usize,
1140            num_interaction_variables as usize,
1141        );
1142
1143        let first_eval_point = Point::<EF>::rand(&mut rng, num_interaction_variables + 1);
1144
1145        let layer = GkrCircuitLayer::FirstLayer(input_layer);
1146
1147        let mut layers = vec![layer];
1148        for _ in 0..num_row_variables - 1 {
1149            let next_layer = match layers.last().unwrap() {
1150                GkrCircuitLayer::Layer(layer) => trace_generator.layer_transition(layer),
1151                GkrCircuitLayer::FirstLayer(layer) => trace_generator.layer_transition(layer),
1152            };
1153            layers.push(GkrCircuitLayer::Layer(next_layer));
1154        }
1155        layers.reverse();
1156
1157        let GkrCircuitLayer::Layer(first_layer) = layers.first().unwrap() else {
1158            panic!("first layer not correct");
1159        };
1160
1161        let output = trace_generator.extract_outputs(first_layer);
1162        assert_eq!(output.numerator.num_variables(), num_interaction_variables + 1);
1163        assert_eq!(output.denominator.num_variables(), num_interaction_variables + 1);
1164
1165        let first_numerator_eval = output.numerator.eval_at(&first_eval_point)[0];
1166        let first_denominator_eval = output.denominator.eval_at(&first_eval_point)[0];
1167
1168        let mut challenger = get_challenger();
1169        let mut round_proofs = Vec::new();
1170        let mut numerator_eval = first_numerator_eval;
1171        let mut denominator_eval = first_denominator_eval;
1172        let mut eval_point = first_eval_point.clone();
1173
1174        for layer in layers {
1175            let round_proof = prove_gkr_round(
1176                layer,
1177                &eval_point,
1178                numerator_eval,
1179                denominator_eval,
1180                &mut challenger,
1181            );
1182            // Observe the prover message.
1183            challenger.observe_ext_element(round_proof.numerator_0);
1184            challenger.observe_ext_element(round_proof.denominator_0);
1185            challenger.observe_ext_element(round_proof.numerator_1);
1186            challenger.observe_ext_element(round_proof.denominator_1);
1187            // Get the evaluation point for the claims.
1188            eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
1189            // Sample the last coordinate.
1190            let last_coordinate = challenger.sample_ext_element::<EF>();
1191
1192            // Compute the evaluation of the numerator and denominator at the last coordinate.
1193            numerator_eval = round_proof.numerator_0
1194                + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
1195            denominator_eval = round_proof.denominator_0
1196                + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
1197            eval_point.add_dimension_back(last_coordinate);
1198            // Add the round proof to the total
1199            round_proofs.push(round_proof);
1200        }
1201
1202        // Follow the GKR protocol layer by layer.
1203        let mut challenger = get_challenger();
1204        let mut numerator_eval = first_numerator_eval;
1205        let mut denominator_eval = first_denominator_eval;
1206        let mut eval_point = first_eval_point;
1207        for (i, round_proof) in round_proofs.iter().enumerate() {
1208            // Get the batching challenge for combining the claims.
1209            let lambda = challenger.sample_ext_element::<EF>();
1210            // Check that the claimed sum is consistent with the previous round values.
1211            let expected_claim = numerator_eval * lambda + denominator_eval;
1212            assert_eq!(round_proof.sumcheck_proof.claimed_sum, expected_claim);
1213
1214            // Verify the sumcheck proof.
1215            partially_verify_sumcheck_proof(
1216                &round_proof.sumcheck_proof,
1217                &mut challenger,
1218                i + num_interaction_variables as usize + 1,
1219                3,
1220            )
1221            .unwrap();
1222
1223            // Verify that the evaluation claim is consistent with the prover messages.
1224            let (point, final_eval) = round_proof.sumcheck_proof.point_and_eval.clone();
1225            let eq_eval = Mle::full_lagrange_eval(&point, &eval_point);
1226            let numerator_sumcheck_eval = round_proof.numerator_0 * round_proof.denominator_1
1227                + round_proof.numerator_1 * round_proof.denominator_0;
1228            let denominator_sumcheck_eval = round_proof.denominator_0 * round_proof.denominator_1;
1229            let expected_final_eval =
1230                eq_eval * (numerator_sumcheck_eval * lambda + denominator_sumcheck_eval);
1231
1232            assert_eq!(final_eval, expected_final_eval, "failed at index {i}");
1233
1234            // Observe the prover message.
1235            challenger.observe_ext_element(round_proof.numerator_0);
1236            challenger.observe_ext_element(round_proof.denominator_0);
1237            challenger.observe_ext_element(round_proof.numerator_1);
1238            challenger.observe_ext_element(round_proof.denominator_1);
1239
1240            // Get the evaluation point for the claims.
1241            eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
1242
1243            // Sample the last coordinate and add to the point.
1244            let last_coordinate = challenger.sample_ext_element::<EF>();
1245            eval_point.add_dimension_back(last_coordinate);
1246            // Update the evaluation of the numerator and denominator at the last coordinate.
1247            numerator_eval = round_proof.numerator_0
1248                + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
1249            denominator_eval = round_proof.denominator_0
1250                + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
1251        }
1252    }
1253}