1use std::{
2 marker::PhantomData,
3 ops::{Add, Mul, Sub},
4 sync::Arc,
5};
6
7use itertools::Itertools;
8use rayon::iter::{ParallelBridge, ParallelIterator};
9use serde::{Deserialize, Serialize};
10use slop_air::Air;
11use slop_algebra::{
12 interpolate_univariate_polynomial, AbstractExtensionField, ExtensionField, Field,
13 UnivariatePolynomial,
14};
15use slop_matrix::dense::RowMajorMatrixView;
16use slop_multilinear::{Mle, PaddedMle};
17use slop_sumcheck::SumcheckPolyBase;
18
19use crate::{air::MachineAir, ConstraintSumcheckFolder};
20use slop_alloc::HasBackend;
21
22use super::ZeroCheckPoly;
23
/// Evaluates the zerocheck round sums for one AIR on the CPU.
///
/// All data is held behind `Arc`s, so cloning the prover only bumps reference counts.
#[derive(Clone)]
pub struct ZerocheckCpuProver<F, EF, A> {
    /// The AIR whose constraints are evaluated on each (extrapolated) row.
    air: Arc<A>,
    /// Public values handed to the constraint folder.
    public_values: Arc<Vec<F>>,
    /// Constraint-batching powers handed to the constraint folder.
    powers_of_alpha: Arc<Vec<EF>>,
    /// Powers used to take a linear combination of each row's main and preprocessed
    /// column values (the GKR term).
    gkr_powers: Arc<Vec<EF>>,
}
35
36impl<F, EF, A> ZerocheckCpuProver<F, EF, A> {
37 pub fn new(
39 air: Arc<A>,
40 public_values: Arc<Vec<F>>,
41 powers_of_alpha: Arc<Vec<EF>>,
42 gkr_powers: Arc<Vec<EF>>,
43 ) -> Self {
44 Self { air, public_values, powers_of_alpha, gkr_powers }
45 }
46}
47
48impl<F, EF, A> ZerocheckCpuProver<F, EF, A>
49where
50 F: Field,
51 EF: ExtensionField<F>,
52{
53 pub(crate) fn sum_as_poly_in_last_variable<K, const IS_FIRST_ROUND: bool>(
54 &self,
55 partial_lagrange: &Mle<EF>,
56 preprocessed_values: Option<&PaddedMle<K>>,
57 main_values: &PaddedMle<K>,
58 ) -> (EF, EF, EF)
59 where
60 K: ExtensionField<F>,
61 EF: ExtensionField<K>,
62 A: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
63 {
64 let air = self.air.clone();
65 let public_values = self.public_values.clone();
66 let powers_of_alpha = self.powers_of_alpha.clone();
67 let gkr_powers = self.gkr_powers.clone();
68 {
69 let num_non_padded_terms = main_values.num_real_entries().div_ceil(2);
70 let eq_chunk_size = std::cmp::max(num_non_padded_terms / num_cpus::get(), 1);
71 let values_chunk_size = eq_chunk_size * 2;
72
73 let eq_guts = partial_lagrange.guts().as_buffer().as_slice();
74
75 let num_main_columns = main_values.num_polynomials();
76 let num_preprocessed_columns =
77 preprocessed_values.map_or(0, slop_multilinear::PaddedMle::num_polynomials);
78
79 let main_values = main_values.inner().as_ref().unwrap().guts().as_buffer().as_slice();
80 let has_preprocessed_values = preprocessed_values.is_some();
81 let preprocessed_values = preprocessed_values.as_ref().map_or([].as_slice(), |p| {
82 p.inner().as_ref().unwrap().guts().as_buffer().as_slice()
83 });
84
85 let eq_guts = eq_guts[0..num_non_padded_terms].to_vec();
87
88 let cumul_ys = eq_guts
89 .chunks(eq_chunk_size)
90 .zip(main_values.chunks(values_chunk_size * num_main_columns))
91 .enumerate()
92 .par_bridge()
93 .map(|(i, (eq_chunk, main_chunk))| {
94 let mut cumul_y_0 = EF::zero();
97 let mut cumul_y_2 = EF::zero();
98 let mut cumul_y_4 = EF::zero();
99
100 let mut main_values_0 = vec![K::zero(); num_main_columns];
101 let mut main_values_2 = vec![K::zero(); num_main_columns];
102 let mut main_values_4 = vec![K::zero(); num_main_columns];
103
104 let mut preprocessed_values_0 = vec![K::zero(); num_preprocessed_columns];
105 let mut preprocessed_values_2 = vec![K::zero(); num_preprocessed_columns];
106 let mut preprocessed_values_4 = vec![K::zero(); num_preprocessed_columns];
107
108 for (j, (eq, main_row)) in
109 eq_chunk.iter().zip(main_chunk.chunks(num_main_columns * 2)).enumerate()
110 {
111 let main_row_0 = &main_row[0..num_main_columns];
112 let main_row_1 = if main_row.len() == 2 * num_main_columns {
113 &main_row[num_main_columns..num_main_columns * 2]
114 } else {
115 &vec![K::zero(); num_main_columns]
117 };
118
119 interpolate_last_var_non_padded_values::<K, IS_FIRST_ROUND>(
120 main_row_0,
121 main_row_1,
122 &mut main_values_0,
123 &mut main_values_2,
124 &mut main_values_4,
125 );
126
127 if has_preprocessed_values {
128 let preprocess_chunk_size =
129 values_chunk_size * num_preprocessed_columns;
130 let preprocessed_row_0_start_idx =
131 i * preprocess_chunk_size + 2 * j * num_preprocessed_columns;
132 let preprocessed_row_0 = &preprocessed_values
133 [preprocessed_row_0_start_idx
134 ..preprocessed_row_0_start_idx + num_preprocessed_columns];
135 let preprocessed_row_1_start_idx =
136 preprocessed_row_0_start_idx + num_preprocessed_columns;
137 let preprocessed_row_1 =
138 if preprocessed_values.len() != preprocessed_row_1_start_idx {
139 &preprocessed_values[preprocessed_row_1_start_idx
140 ..preprocessed_row_1_start_idx + num_preprocessed_columns]
141 } else {
142 &vec![K::zero(); num_preprocessed_columns]
144 };
145
146 interpolate_last_var_non_padded_values::<K, IS_FIRST_ROUND>(
147 preprocessed_row_0,
148 preprocessed_row_1,
149 &mut preprocessed_values_0,
150 &mut preprocessed_values_2,
151 &mut preprocessed_values_4,
152 );
153 }
154
155 increment_y_values::<K, F, EF, A, IS_FIRST_ROUND>(
156 &public_values,
157 &powers_of_alpha,
158 &air,
159 &mut cumul_y_0,
160 &mut cumul_y_2,
161 &mut cumul_y_4,
162 &preprocessed_values_0,
163 &main_values_0,
164 &preprocessed_values_2,
165 &main_values_2,
166 &preprocessed_values_4,
167 &main_values_4,
168 &gkr_powers,
169 *eq,
170 );
171 }
172 (cumul_y_0, cumul_y_2, cumul_y_4)
173 })
174 .collect::<Vec<_>>();
175
176 cumul_ys.into_iter().fold(
177 (EF::zero(), EF::zero(), EF::zero()),
178 |(y_0, y_2, y_4), (y_0_i, y_2_i, y_4_i)| (y_0 + y_0_i, y_2 + y_2_i, y_4 + y_4_i),
179 )
180 }
181 }
182}
183
184pub fn zerocheck_sum_as_poly_in_last_variable<
188 K: ExtensionField<F>,
189 F: Field,
190 EF: ExtensionField<F> + ExtensionField<K> + ExtensionField<F> + AbstractExtensionField<K>,
191 AirData,
192 const IS_FIRST_ROUND: bool,
193>(
194 poly: &ZeroCheckPoly<K, F, EF, AirData>,
195 claim: Option<EF>,
196) -> UnivariatePolynomial<EF>
197where
198 AirData: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
199{
200 let num_real_entries = poly.main_columns.num_real_entries();
201 if num_real_entries == 0 {
202 return UnivariatePolynomial::zero(4);
205 }
206
207 let claim = claim.expect("claim must be provided");
208
209 let (rest_point_host, last) = poly.zeta.split_at(poly.zeta.dimension() - 1);
210 let last = *last[0];
211
212 let partial_lagrange: Mle<EF> = Mle::partial_lagrange(&rest_point_host);
214 let partial_lagrange = Arc::new(partial_lagrange);
215
216 let mut xs = Vec::new();
226 let mut ys = Vec::new();
227
228 let (mut y_0, mut y_2, mut y_4) =
229 poly.air_data.sum_as_poly_in_last_variable::<K, IS_FIRST_ROUND>(
230 partial_lagrange.as_ref(),
231 poly.preprocessed_columns.as_ref(),
232 &poly.main_columns,
233 );
234
235 let virtual_geq = poly.virtual_geq;
237
238 let threshold_half = poly.main_columns.num_real_entries().div_ceil(2) - 1;
239 let msb_lagrange_eval: EF = poly.eq_adjustment
240 * if threshold_half < (1 << (poly.num_variables() - 1)) {
241 partial_lagrange.guts().as_buffer()[threshold_half]
242 .copy_into_host(partial_lagrange.backend())
243 } else {
244 EF::zero()
245 };
246
247 let virtual_0 = virtual_geq.fix_last_variable(EF::zero()).eval_at_usize(threshold_half);
248 let virtual_2 = virtual_geq.fix_last_variable(EF::two()).eval_at_usize(threshold_half);
249 let virtual_4 =
250 virtual_geq.fix_last_variable(EF::from_canonical_usize(4)).eval_at_usize(threshold_half);
251
252 xs.push(EF::zero());
253
254 let eq_last_term_factor = EF::one() - last;
255 y_0 *= eq_last_term_factor * poly.eq_adjustment;
256 y_0 -= poly.padded_row_adjustment * virtual_0 * msb_lagrange_eval * eq_last_term_factor;
257 ys.push(y_0);
258
259 xs.push(EF::one());
261
262 let y_1 = claim - y_0;
263 ys.push(y_1);
264
265 xs.push(EF::from_canonical_usize(2));
267 let eq_last_term_factor = last * F::from_canonical_usize(3) - EF::one();
268 y_2 *= eq_last_term_factor * poly.eq_adjustment;
269 y_2 -= poly.padded_row_adjustment * virtual_2 * msb_lagrange_eval * eq_last_term_factor;
270 ys.push(y_2);
271
272 xs.push(EF::from_canonical_usize(4));
274 let eq_last_term_factor = last * F::from_canonical_usize(7) - F::from_canonical_usize(3);
275 y_4 *= eq_last_term_factor * poly.eq_adjustment;
276 y_4 -= poly.padded_row_adjustment * virtual_4 * msb_lagrange_eval * eq_last_term_factor;
277 ys.push(y_4);
278
279 let point_elements = poly.zeta.to_vec();
281 let point_first = point_elements.last().unwrap();
282 let b_const = (EF::one() - *point_first) / (EF::one() - point_first.double());
283 xs.push(b_const);
284 ys.push(EF::zero());
285
286 interpolate_univariate_polynomial(&xs, &ys)
287}
288
289fn interpolate_last_var_non_padded_values<K: Field, const IS_FIRST_ROUND: bool>(
292 row_0: &[K],
293 row_1: &[K],
294 vals_0: &mut [K],
295 vals_2: &mut [K],
296 vals_4: &mut [K],
297) {
298 for (i, (row_0_val, row_1_val)) in row_0.iter().zip_eq(row_1.iter()).enumerate() {
299 let slope = *row_1_val - *row_0_val;
300 let slope_times_2 = slope + slope;
301 let slope_times_4 = slope_times_2 + slope_times_2;
302
303 vals_0[i] = *row_0_val;
304
305 vals_2[i] = slope_times_2 + *row_0_val;
306 vals_4[i] = slope_times_4 + *row_0_val;
307 }
308}
309
/// Zero-sized marker type, parameterized by the AIR type `A`, whose `round_prover`
/// constructs a `ZerocheckCpuProver`.
#[derive(Clone, Debug, Copy, Serialize, Deserialize)]
pub struct ZerocheckCpuProverData<A>(PhantomData<A>);
313
314impl<A> Default for ZerocheckCpuProverData<A> {
315 fn default() -> Self {
316 Self(PhantomData)
317 }
318}
319
320impl<A> ZerocheckCpuProverData<A> {
321 pub fn round_prover<F, EF>(
323 air: Arc<A>,
324 public_values: Arc<Vec<F>>,
325 powers_of_alpha: Arc<Vec<EF>>,
326 gkr_powers: Arc<Vec<EF>>,
327 ) -> ZerocheckCpuProver<F, EF, A>
328 where
329 F: Field,
330 EF: ExtensionField<F>,
331 A: for<'b> Air<ConstraintSumcheckFolder<'b, F, F, EF>>
332 + for<'b> Air<ConstraintSumcheckFolder<'b, F, EF, EF>>
333 + MachineAir<F>,
334 {
335 ZerocheckCpuProver::new(air, public_values, powers_of_alpha, gkr_powers)
336 }
337}
338
339#[must_use]
343pub fn interpolate_last_var_padded_values<K: Field>(values: &Mle<K>) -> (Vec<K>, Vec<K>, Vec<K>) {
344 let row_0 = values.guts().as_slice().iter();
345 let vals_0 = row_0.clone().copied().collect::<Vec<_>>();
346 let vals_2 = row_0.clone().map(|val| -(*val)).collect::<Vec<_>>();
347 let vals_4 = row_0.clone().map(|val| -K::from_canonical_usize(3) * (*val)).collect::<Vec<_>>();
348
349 (vals_0, vals_2, vals_4)
350}
351
352#[allow(clippy::too_many_arguments)]
355pub fn increment_y_values<
356 'a,
357 K: Field + From<F> + Add<F, Output = K> + Sub<F, Output = K> + Mul<F, Output = K>,
358 F: Field,
359 EF: ExtensionField<F> + From<K> + ExtensionField<F> + AbstractExtensionField<K>,
360 A: for<'b> Air<ConstraintSumcheckFolder<'b, F, K, EF>> + MachineAir<F>,
361 const IS_FIRST_ROUND: bool,
362>(
363 public_values: &[F],
364 powers_of_alpha: &[EF],
365 air: &A,
366 y_0: &mut EF,
367 y_2: &mut EF,
368 y_4: &mut EF,
369 preprocessed_column_vals_0: &[K],
370 main_column_vals_0: &[K],
371 preprocessed_column_vals_2: &[K],
372 main_column_vals_2: &[K],
373 preprocessed_column_vals_4: &[K],
374 main_column_vals_4: &[K],
375 interaction_batching_powers: &[EF],
376 eq: EF,
377) {
378 let mut y_0_adjustment = EF::zero();
379 if !IS_FIRST_ROUND {
381 let mut folder = ConstraintSumcheckFolder {
382 preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_0),
383 main: RowMajorMatrixView::new_row(main_column_vals_0),
384 accumulator: EF::zero(),
385 public_values,
386 constraint_index: 0,
387 powers_of_alpha,
388 };
389 air.eval(&mut folder);
390 y_0_adjustment += folder.accumulator;
391 }
392
393 let gkr_adjustment_0 = main_column_vals_0
394 .iter()
395 .copied()
396 .chain(preprocessed_column_vals_0.iter().copied())
397 .zip(interaction_batching_powers.iter().copied())
398 .map(|(val, power)| power * val)
399 .sum::<EF>();
400
401 y_0_adjustment += gkr_adjustment_0;
402 *y_0 += y_0_adjustment * eq;
403
404 let mut y_2_adjustment = EF::zero();
405
406 let mut folder = ConstraintSumcheckFolder {
408 preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_2),
409 main: RowMajorMatrixView::new_row(main_column_vals_2),
410 accumulator: EF::zero(),
411 public_values,
412 constraint_index: 0,
413 powers_of_alpha,
414 };
415 air.eval(&mut folder);
416
417 y_2_adjustment += folder.accumulator;
418 let gkr_adjustment_2 = main_column_vals_2
419 .iter()
420 .copied()
421 .chain(preprocessed_column_vals_2.iter().copied())
422 .zip(interaction_batching_powers.iter().copied())
423 .map(|(val, power)| power * val)
424 .sum::<EF>();
425 y_2_adjustment += gkr_adjustment_2;
426 *y_2 += y_2_adjustment * eq;
427
428 let mut folder = ConstraintSumcheckFolder {
430 preprocessed: RowMajorMatrixView::new_row(preprocessed_column_vals_4),
431 main: RowMajorMatrixView::new_row(main_column_vals_4),
432 accumulator: EF::zero(),
433 public_values,
434 constraint_index: 0,
435 powers_of_alpha,
436 };
437 let gkr_adjustment_4 = gkr_adjustment_2 + gkr_adjustment_2 - gkr_adjustment_0;
438 air.eval(&mut folder);
439 *y_4 += (folder.accumulator + gkr_adjustment_4) * eq;
440}