// sp1_hypercube/prover/zerocheck/sum_as_poly.rs

1use std::{
2    marker::PhantomData,
3    ops::{Add, Mul, Sub},
4    sync::Arc,
5};
6
7use itertools::Itertools;
8use rayon::iter::{ParallelBridge, ParallelIterator};
9use serde::{Deserialize, Serialize};
10use slop_air::Air;
11use slop_algebra::{
12    interpolate_univariate_polynomial, AbstractExtensionField, ExtensionField, Field,
13    UnivariatePolynomial,
14};
15use slop_matrix::dense::RowMajorMatrixView;
16use slop_multilinear::{Mle, PaddedMle};
17use slop_sumcheck::SumcheckPolyBase;
18
19use crate::{air::MachineAir, ConstraintSumcheckFolder};
20use slop_alloc::HasBackend;
21
22use super::ZeroCheckPoly;
23
/// Zerocheck data for the CPU backend.
///
/// Every field is reference-counted, so cloning the prover only bumps reference counts.
pub struct ZerocheckCpuProver<F, EF, A> {
    /// The AIR that contains the constraint polynomial.
    air: Arc<A>,
    /// The public values.
    public_values: Arc<Vec<F>>,
    /// The powers of alpha.
    powers_of_alpha: Arc<Vec<EF>>,
    /// Powers used to batch the column values into the GKR term of the zerocheck summand
    /// (forwarded to `increment_y_values` as `interaction_batching_powers`).
    gkr_powers: Arc<Vec<EF>>,
}

// Implemented by hand: `#[derive(Clone)]` would require `F`, `EF` and `A` to be `Clone`,
// even though only the `Arc` handles are cloned.
impl<F, EF, A> Clone for ZerocheckCpuProver<F, EF, A> {
    fn clone(&self) -> Self {
        Self {
            air: Arc::clone(&self.air),
            public_values: Arc::clone(&self.public_values),
            powers_of_alpha: Arc::clone(&self.powers_of_alpha),
            gkr_powers: Arc::clone(&self.gkr_powers),
        }
    }
}

impl<F, EF, A> ZerocheckCpuProver<F, EF, A> {
    /// Creates a new `ZerocheckCpuProver` from its shared components.
    pub fn new(
        air: Arc<A>,
        public_values: Arc<Vec<F>>,
        powers_of_alpha: Arc<Vec<EF>>,
        gkr_powers: Arc<Vec<EF>>,
    ) -> Self {
        Self { air, public_values, powers_of_alpha, gkr_powers }
    }
}
47
48impl<F, EF, A> ZerocheckCpuProver<F, EF, A>
49where
50    F: Field,
51    EF: ExtensionField<F>,
52{
53    pub(crate) fn sum_as_poly_in_last_variable<K, const IS_FIRST_ROUND: bool>(
54        &self,
55        partial_lagrange: &Mle<EF>,
56        preprocessed_values: Option<&PaddedMle<K>>,
57        main_values: &PaddedMle<K>,
58    ) -> (EF, EF, EF)
59    where
60        K: ExtensionField<F>,
61        EF: ExtensionField<K>,
62        A: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
63    {
64        let air = self.air.clone();
65        let public_values = self.public_values.clone();
66        let powers_of_alpha = self.powers_of_alpha.clone();
67        let gkr_powers = self.gkr_powers.clone();
68        {
69            let num_non_padded_terms = main_values.num_real_entries().div_ceil(2);
70            let eq_chunk_size = std::cmp::max(num_non_padded_terms / num_cpus::get(), 1);
71            let values_chunk_size = eq_chunk_size * 2;
72
73            let eq_guts = partial_lagrange.guts().as_buffer().as_slice();
74
75            let num_main_columns = main_values.num_polynomials();
76            let num_preprocessed_columns =
77                preprocessed_values.map_or(0, slop_multilinear::PaddedMle::num_polynomials);
78
79            let main_values = main_values.inner().as_ref().unwrap().guts().as_buffer().as_slice();
80            let has_preprocessed_values = preprocessed_values.is_some();
81            let preprocessed_values = preprocessed_values.as_ref().map_or([].as_slice(), |p| {
82                p.inner().as_ref().unwrap().guts().as_buffer().as_slice()
83            });
84
85            // Handle the case when the zerocheck polynomial has non-padded variables.
86            let eq_guts = eq_guts[0..num_non_padded_terms].to_vec();
87
88            let cumul_ys = eq_guts
89                .chunks(eq_chunk_size)
90                .zip(main_values.chunks(values_chunk_size * num_main_columns))
91                .enumerate()
92                .par_bridge()
93                .map(|(i, (eq_chunk, main_chunk))| {
94                    // Evaluate the constraint polynomial at the points 0, 2, and 4, and
95                    // add the results to the y_0, y_2, and y_4 accumulators.
96                    let mut cumul_y_0 = EF::zero();
97                    let mut cumul_y_2 = EF::zero();
98                    let mut cumul_y_4 = EF::zero();
99
100                    let mut main_values_0 = vec![K::zero(); num_main_columns];
101                    let mut main_values_2 = vec![K::zero(); num_main_columns];
102                    let mut main_values_4 = vec![K::zero(); num_main_columns];
103
104                    let mut preprocessed_values_0 = vec![K::zero(); num_preprocessed_columns];
105                    let mut preprocessed_values_2 = vec![K::zero(); num_preprocessed_columns];
106                    let mut preprocessed_values_4 = vec![K::zero(); num_preprocessed_columns];
107
108                    for (j, (eq, main_row)) in
109                        eq_chunk.iter().zip(main_chunk.chunks(num_main_columns * 2)).enumerate()
110                    {
111                        let main_row_0 = &main_row[0..num_main_columns];
112                        let main_row_1 = if main_row.len() == 2 * num_main_columns {
113                            &main_row[num_main_columns..num_main_columns * 2]
114                        } else {
115                            // Provide a dummy row if there is an odd number of rows.
116                            &vec![K::zero(); num_main_columns]
117                        };
118
119                        interpolate_last_var_non_padded_values::<K, IS_FIRST_ROUND>(
120                            main_row_0,
121                            main_row_1,
122                            &mut main_values_0,
123                            &mut main_values_2,
124                            &mut main_values_4,
125                        );
126
127                        if has_preprocessed_values {
128                            let preprocess_chunk_size =
129                                values_chunk_size * num_preprocessed_columns;
130                            let preprocessed_row_0_start_idx =
131                                i * preprocess_chunk_size + 2 * j * num_preprocessed_columns;
132                            let preprocessed_row_0 = &preprocessed_values
133                                [preprocessed_row_0_start_idx
134                                    ..preprocessed_row_0_start_idx + num_preprocessed_columns];
135                            let preprocessed_row_1_start_idx =
136                                preprocessed_row_0_start_idx + num_preprocessed_columns;
137                            let preprocessed_row_1 =
138                                if preprocessed_values.len() != preprocessed_row_1_start_idx {
139                                    &preprocessed_values[preprocessed_row_1_start_idx
140                                        ..preprocessed_row_1_start_idx + num_preprocessed_columns]
141                                } else {
142                                    // Provide padding values if there is an odd number of rows.
143                                    &vec![K::zero(); num_preprocessed_columns]
144                                };
145
146                            interpolate_last_var_non_padded_values::<K, IS_FIRST_ROUND>(
147                                preprocessed_row_0,
148                                preprocessed_row_1,
149                                &mut preprocessed_values_0,
150                                &mut preprocessed_values_2,
151                                &mut preprocessed_values_4,
152                            );
153                        }
154
155                        increment_y_values::<K, F, EF, A, IS_FIRST_ROUND>(
156                            &public_values,
157                            &powers_of_alpha,
158                            &air,
159                            &mut cumul_y_0,
160                            &mut cumul_y_2,
161                            &mut cumul_y_4,
162                            &preprocessed_values_0,
163                            &main_values_0,
164                            &preprocessed_values_2,
165                            &main_values_2,
166                            &preprocessed_values_4,
167                            &main_values_4,
168                            &gkr_powers,
169                            *eq,
170                        );
171                    }
172                    (cumul_y_0, cumul_y_2, cumul_y_4)
173                })
174                .collect::<Vec<_>>();
175
176            cumul_ys.into_iter().fold(
177                (EF::zero(), EF::zero(), EF::zero()),
178                |(y_0, y_2, y_4), (y_0_i, y_2_i, y_4_i)| (y_0 + y_0_i, y_2 + y_2_i, y_4 + y_4_i),
179            )
180        }
181    }
182}
183
184/// This function will calculate the univariate polynomial where all variables other than the last
185/// are summed on the boolean hypercube and the last variable is left as a free variable.
186/// TODO:  Add flexibility to support degree 2 and degree 3 constraint polynomials.
187pub fn zerocheck_sum_as_poly_in_last_variable<
188    K: ExtensionField<F>,
189    F: Field,
190    EF: ExtensionField<F> + ExtensionField<K> + ExtensionField<F> + AbstractExtensionField<K>,
191    AirData,
192    const IS_FIRST_ROUND: bool,
193>(
194    poly: &ZeroCheckPoly<K, F, EF, AirData>,
195    claim: Option<EF>,
196) -> UnivariatePolynomial<EF>
197where
198    AirData: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
199{
200    let num_real_entries = poly.main_columns.num_real_entries();
201    if num_real_entries == 0 {
202        // NOTE: We hard-code the degree of the zerocheck to be three here. This is important to get
203        // the correct shape of a dummy proof.
204        return UnivariatePolynomial::zero(4);
205    }
206
207    let claim = claim.expect("claim must be provided");
208
209    let (rest_point_host, last) = poly.zeta.split_at(poly.zeta.dimension() - 1);
210    let last = *last[0];
211
212    // TODO:  Optimization of computing this once per zerocheck sumcheck.
213    let partial_lagrange: Mle<EF> = Mle::partial_lagrange(&rest_point_host);
214    let partial_lagrange = Arc::new(partial_lagrange);
215
216    // For the first round, we know that at point 0 and 1, the zerocheck polynomial will evaluate to
217    // 0. For all rounds, we can find a root of the zerocheck polynomial by finding a root of
218    // the eq term in the last coord.
219    // So for the first round, we need to find an additional 2 points (since the constraint
220    // polynomial is degree 3). We calculate the eval at points 2 and 4 (since we don't need to
221    // do any multiplications when interpolating the column evals).
222    // For the other rounds, we need to find an additional 1 point since we don't know the zercheck
223    // poly eval at point 0 and 1.
224    // We calculate the eval at point 0 and then infer the eval at point 1 by the passed in claim.
225    let mut xs = Vec::new();
226    let mut ys = Vec::new();
227
228    let (mut y_0, mut y_2, mut y_4) =
229        poly.air_data.sum_as_poly_in_last_variable::<K, IS_FIRST_ROUND>(
230            partial_lagrange.as_ref(),
231            poly.preprocessed_columns.as_ref(),
232            &poly.main_columns,
233        );
234
235    // Add the point 0 and it's eval to the xs and ys.
236    let virtual_geq = poly.virtual_geq;
237
238    let threshold_half = poly.main_columns.num_real_entries().div_ceil(2) - 1;
239    let msb_lagrange_eval: EF = poly.eq_adjustment
240        * if threshold_half < (1 << (poly.num_variables() - 1)) {
241            partial_lagrange.guts().as_buffer()[threshold_half]
242                .copy_into_host(partial_lagrange.backend())
243        } else {
244            EF::zero()
245        };
246
247    let virtual_0 = virtual_geq.fix_last_variable(EF::zero()).eval_at_usize(threshold_half);
248    let virtual_2 = virtual_geq.fix_last_variable(EF::two()).eval_at_usize(threshold_half);
249    let virtual_4 =
250        virtual_geq.fix_last_variable(EF::from_canonical_usize(4)).eval_at_usize(threshold_half);
251
252    xs.push(EF::zero());
253
254    let eq_last_term_factor = EF::one() - last;
255    y_0 *= eq_last_term_factor * poly.eq_adjustment;
256    y_0 -= poly.padded_row_adjustment * virtual_0 * msb_lagrange_eval * eq_last_term_factor;
257    ys.push(y_0);
258
259    // Add the point 1 and it's eval to the xs and ys.
260    xs.push(EF::one());
261
262    let y_1 = claim - y_0;
263    ys.push(y_1);
264
265    // Add the point 2 and it's eval to the xs and ys.
266    xs.push(EF::from_canonical_usize(2));
267    let eq_last_term_factor = last * F::from_canonical_usize(3) - EF::one();
268    y_2 *= eq_last_term_factor * poly.eq_adjustment;
269    y_2 -= poly.padded_row_adjustment * virtual_2 * msb_lagrange_eval * eq_last_term_factor;
270    ys.push(y_2);
271
272    // Add the point 4 and it's eval to the xs and ys.
273    xs.push(EF::from_canonical_usize(4));
274    let eq_last_term_factor = last * F::from_canonical_usize(7) - F::from_canonical_usize(3);
275    y_4 *= eq_last_term_factor * poly.eq_adjustment;
276    y_4 -= poly.padded_row_adjustment * virtual_4 * msb_lagrange_eval * eq_last_term_factor;
277    ys.push(y_4);
278
279    // Add the eq_first_term_root point and it's eval to the xs and ys.
280    let point_elements = poly.zeta.to_vec();
281    let point_first = point_elements.last().unwrap();
282    let b_const = (EF::one() - *point_first) / (EF::one() - point_first.double());
283    xs.push(b_const);
284    ys.push(EF::zero());
285
286    interpolate_univariate_polynomial(&xs, &ys)
287}
288
289/// This function will calculate the column values where the last variable is set to 0, 2, and 4
290/// and it's a non-padded variable.
291fn interpolate_last_var_non_padded_values<K: Field, const IS_FIRST_ROUND: bool>(
292    row_0: &[K],
293    row_1: &[K],
294    vals_0: &mut [K],
295    vals_2: &mut [K],
296    vals_4: &mut [K],
297) {
298    for (i, (row_0_val, row_1_val)) in row_0.iter().zip_eq(row_1.iter()).enumerate() {
299        let slope = *row_1_val - *row_0_val;
300        let slope_times_2 = slope + slope;
301        let slope_times_4 = slope_times_2 + slope_times_2;
302
303        vals_0[i] = *row_0_val;
304
305        vals_2[i] = slope_times_2 + *row_0_val;
306        vals_4[i] = slope_times_4 + *row_0_val;
307    }
308}
309
310/// The data required to produce zerocheck proofs on CPU.
311#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
312pub struct ZerocheckCpuProverData<A>(PhantomData<A>);
313
314impl<A> Default for ZerocheckCpuProverData<A> {
315    fn default() -> Self {
316        Self(PhantomData)
317    }
318}
319
320impl<A> ZerocheckCpuProverData<A> {
321    /// Creates a round prover for zerocheck.
322    pub fn round_prover<F, EF>(
323        air: Arc<A>,
324        public_values: Arc<Vec<F>>,
325        powers_of_alpha: Arc<Vec<EF>>,
326        gkr_powers: Arc<Vec<EF>>,
327    ) -> ZerocheckCpuProver<F, EF, A>
328    where
329        F: Field,
330        EF: ExtensionField<F>,
331        A: for<'b> Air<ConstraintSumcheckFolder<'b, F, F, EF>>
332            + for<'b> Air<ConstraintSumcheckFolder<'b, F, EF, EF>>
333            + MachineAir<F>,
334    {
335        ZerocheckCpuProver::new(air, public_values, powers_of_alpha, gkr_powers)
336    }
337}
338
339/// This function will calculate the column values where the last variable is set to 0, 2, and 4
340/// and it's a padded variable.  The `row_0` values are taken from the values matrix (which should
341/// have a height of 1).  The `row_1` values are all zero.
342#[must_use]
343pub fn interpolate_last_var_padded_values<K: Field>(values: &Mle<K>) -> (Vec<K>, Vec<K>, Vec<K>) {
344    let row_0 = values.guts().as_slice().iter();
345    let vals_0 = row_0.clone().copied().collect::<Vec<_>>();
346    let vals_2 = row_0.clone().map(|val| -(*val)).collect::<Vec<_>>();
347    let vals_4 = row_0.clone().map(|val| -K::from_canonical_usize(3) * (*val)).collect::<Vec<_>>();
348
349    (vals_0, vals_2, vals_4)
350}
351
352/// This function will increment the y0, y2, and y4 accumulators by the eval of the constraint
353/// polynomial at the points 0, 2, and 4.
354#[allow(clippy::too_many_arguments)]
355pub fn increment_y_values<
356    'a,
357    K: Field + From<F> + Add<F, Output = K> + Sub<F, Output = K> + Mul<F, Output = K>,
358    F: Field,
359    EF: ExtensionField<F> + From<K> + ExtensionField<F> + AbstractExtensionField<K>,
360    A: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
361    const IS_FIRST_ROUND: bool,
362>(
363    public_values: &[F],
364    powers_of_alpha: &[EF],
365    air: &A,
366    y_0: &mut EF,
367    y_2: &mut EF,
368    y_4: &mut EF,
369    preprocessed_column_vals_0: &[K],
370    main_column_vals_0: &[K],
371    preprocessed_column_vals_2: &[K],
372    main_column_vals_2: &[K],
373    preprocessed_column_vals_4: &[K],
374    main_column_vals_4: &[K],
375    interaction_batching_powers: &[EF],
376    eq: EF,
377) {
378    let mut y_0_adjustment = EF::zero();
379    // Add to the y_0 accumulator.
380    if !IS_FIRST_ROUND {
381        let mut folder = ConstraintSumcheckFolder {
382            preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_0),
383            main: RowMajorMatrixView::new_row(main_column_vals_0),
384            accumulator: EF::zero(),
385            public_values,
386            constraint_index: 0,
387            powers_of_alpha,
388        };
389        air.eval(&mut folder);
390        y_0_adjustment += folder.accumulator;
391    }
392
393    let gkr_adjustment_0 = main_column_vals_0
394        .iter()
395        .copied()
396        .chain(preprocessed_column_vals_0.iter().copied())
397        .zip(interaction_batching_powers.iter().copied())
398        .map(|(val, power)| power * val)
399        .sum::<EF>();
400
401    y_0_adjustment += gkr_adjustment_0;
402    *y_0 += y_0_adjustment * eq;
403
404    let mut y_2_adjustment = EF::zero();
405
406    // Add to the y_2 accumulator.
407    let mut folder = ConstraintSumcheckFolder {
408        preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_2),
409        main: RowMajorMatrixView::new_row(main_column_vals_2),
410        accumulator: EF::zero(),
411        public_values,
412        constraint_index: 0,
413        powers_of_alpha,
414    };
415    air.eval(&mut folder);
416
417    y_2_adjustment += folder.accumulator;
418    let gkr_adjustment_2 = main_column_vals_2
419        .iter()
420        .copied()
421        .chain(preprocessed_column_vals_2.iter().copied())
422        .zip(interaction_batching_powers.iter().copied())
423        .map(|(val, power)| power * val)
424        .sum::<EF>();
425    y_2_adjustment += gkr_adjustment_2;
426    *y_2 += y_2_adjustment * eq;
427
428    // Add to the y_4 accumulator.
429    let mut folder = ConstraintSumcheckFolder {
430        preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_4),
431        main: RowMajorMatrixView::new_row(main_column_vals_4),
432        accumulator: EF::zero(),
433        public_values,
434        constraint_index: 0,
435        powers_of_alpha,
436    };
437    let gkr_adjustment_4 = gkr_adjustment_2 + gkr_adjustment_2 - gkr_adjustment_0;
438    air.eval(&mut folder);
439    *y_4 += (folder.accumulator + gkr_adjustment_4) * eq;
440}