1use std::{
2 collections::{BTreeMap, BTreeSet},
3 marker::PhantomData,
4};
5
6use itertools::Itertools;
7use slop_algebra::{ExtensionField, Field};
8use slop_challenger::{FieldChallenger, VariableLengthChallenger};
9use slop_multilinear::{
10 full_geq, partial_lagrange_blocking, Mle, MleEval, MultilinearPcsChallenger, Point,
11};
12use slop_sumcheck::{partially_verify_sumcheck_proof, SumcheckError};
13use thiserror::Error;
14
15use crate::{air::MachineAir, Chip};
16
17use super::{ChipEvaluation, LogUpEvaluations, LogUpGkrOutput, LogupGkrProof};
18
/// Errors returned by [`LogUpGkrVerifier::verify_logup_gkr`].
///
/// `EF` is the extension-field element type; it is carried by the mismatch variants so that
/// both compared values appear in the error message.
#[derive(Debug, Error)]
pub enum LogupGkrVerificationError<EF> {
    /// The claimed sum of the sumcheck proof in the given round (0-based) differs from the
    /// batched claim `numerator_eval * lambda + denominator_eval` carried over from the
    /// previous layer.
    #[error("inconsistent sumcheck claim at round {0}")]
    InconsistentSumcheckClaim(usize),
    /// The final evaluation of the sumcheck proof in the given round (0-based) differs from the
    /// value recomputed from that round's `numerator_{0,1}` / `denominator_{0,1}` messages.
    #[error("inconsistent evaluation at round {0}")]
    InconsistentEvaluation(usize),
    /// An error propagated from `partially_verify_sumcheck_proof`.
    #[error("sumcheck error: {0}")]
    SumcheckError(#[from] SumcheckError),
    /// Some part of the proof or of the verifier inputs does not have the expected shape.
    #[error("invalid shape")]
    InvalidShape,
    /// The circuit output has an unexpected number of variables.
    /// Fields are `(actual, expected)`.
    #[error("invalid first layer dimension: {0} != {1}")]
    InvalidFirstLayerDimension(u32, u32),
    /// The trace part of the final evaluation point has an unexpected dimension.
    /// Fields are `(actual, expected)`.
    #[error("invalid last layer dimension: {0} != {1}")]
    InvalidLastLayerDimension(usize, usize),
    /// The point stored with the chip openings differs from the trace point derived by the
    /// verifier.
    #[error("trace point mismatch")]
    TracePointMismatch,
    /// The sum of `numerator / denominator` over the circuit output differs from the claimed
    /// cumulative sum. Fields are `(computed from the output, claimed)`.
    #[error("cumulative sum mismatch: {0} != {1}")]
    CumulativeSumMismatch(EF, EF),
    /// The numerator claim left after the last round differs from the value recomputed from the
    /// chip openings. Fields are `(claim from the rounds, recomputed from the openings)`.
    #[error("numerator evaluation mismatch: {0} != {1}")]
    NumeratorEvaluationMismatch(EF, EF),
    /// The denominator claim left after the last round differs from the value recomputed from
    /// the chip openings. Fields are `(claim from the rounds, recomputed from the openings)`.
    #[error("denominator evaluation mismatch: {0} != {1}")]
    DenominatorEvaluationMismatch(EF, EF),
    /// At least one entry of the circuit output denominator is zero, so the output fractions
    /// cannot be summed.
    #[error("denominator evaluation has zero value")]
    ZeroDenominator,
}
56
/// Verifier for LogUp GKR proofs.
///
/// The struct stores no data: its only field is a [`PhantomData`] that ties together the base
/// field `F`, the extension field `EF` and the AIR type `A` used by the associated functions.
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Hash)]
pub struct LogUpGkrVerifier<F, EF, A>(PhantomData<(F, EF, A)>);
60
61impl<F, EF, A> LogUpGkrVerifier<F, EF, A>
62where
63 F: Field,
64 EF: ExtensionField<F>,
65 A: MachineAir<F>,
66{
67 #[allow(clippy::too_many_arguments)]
71 #[allow(clippy::too_many_lines)]
72 pub fn verify_logup_gkr(
73 shard_chips: &BTreeSet<Chip<F, A>>,
74 degrees: &BTreeMap<String, Point<F>>,
75 alpha: EF,
76 beta_seed: &Point<EF>,
77 cumulative_sum: EF,
78 max_log_row_count: usize,
79 proof: &LogupGkrProof<EF>,
80 challenger: &mut impl FieldChallenger<F>,
81 ) -> Result<(), LogupGkrVerificationError<EF>> {
82 let LogupGkrProof { circuit_output, round_proofs, logup_evaluations } = proof;
83
84 let LogUpGkrOutput { numerator, denominator } = circuit_output;
85
86 let num_of_interactions =
88 shard_chips.iter().map(|c| c.sends().len() + c.receives().len()).sum::<usize>();
89 let number_of_interaction_variables = num_of_interactions.next_power_of_two().ilog2();
90
91 let expected_size = 1 << (number_of_interaction_variables + 1);
92
93 if numerator.guts().dimensions.sizes() != [expected_size, 1]
94 || denominator.guts().dimensions.sizes() != [expected_size, 1]
95 {
96 return Err(LogupGkrVerificationError::InvalidShape);
97 }
98
99 challenger.observe_variable_length_extension_slice(numerator.guts().as_slice());
101 challenger.observe_variable_length_extension_slice(denominator.guts().as_slice());
102
103 if denominator.guts().as_slice().iter().any(slop_algebra::Field::is_zero) {
104 return Err(LogupGkrVerificationError::ZeroDenominator);
105 }
106
107 let output_cumulative_sum = numerator
109 .guts()
110 .as_slice()
111 .iter()
112 .zip_eq(denominator.guts().as_slice().iter())
113 .map(|(n, d)| *n / *d)
114 .sum::<EF>();
115 if output_cumulative_sum != cumulative_sum {
116 return Err(LogupGkrVerificationError::CumulativeSumMismatch(
117 output_cumulative_sum,
118 cumulative_sum,
119 ));
120 }
121
122 let initial_number_of_variables = numerator.num_variables();
124 if initial_number_of_variables != number_of_interaction_variables + 1 {
125 return Err(LogupGkrVerificationError::InvalidFirstLayerDimension(
126 initial_number_of_variables,
127 number_of_interaction_variables + 1,
128 ));
129 }
130 let first_eval_point = challenger.sample_point::<EF>(initial_number_of_variables);
132
133 let mut numerator_eval = numerator.blocking_eval_at(&first_eval_point)[0];
135 let mut denominator_eval = denominator.blocking_eval_at(&first_eval_point)[0];
136 let mut eval_point = first_eval_point;
137
138 if round_proofs.len() + 1 != max_log_row_count {
139 return Err(LogupGkrVerificationError::InvalidShape);
140 }
141
142 for (i, round_proof) in round_proofs.iter().enumerate() {
143 let lambda = challenger.sample_ext_element::<EF>();
145 let expected_claim = numerator_eval * lambda + denominator_eval;
147 if round_proof.sumcheck_proof.claimed_sum != expected_claim {
148 return Err(LogupGkrVerificationError::InconsistentSumcheckClaim(i));
149 }
150 partially_verify_sumcheck_proof(
152 &round_proof.sumcheck_proof,
153 challenger,
154 i + number_of_interaction_variables as usize + 1,
155 3,
156 )?;
157 let (point, final_eval) = round_proof.sumcheck_proof.point_and_eval.clone();
159 let eq_eval = Mle::full_lagrange_eval(&point, &eval_point);
160 let numerator_sumcheck_eval = round_proof.numerator_0 * round_proof.denominator_1
161 + round_proof.numerator_1 * round_proof.denominator_0;
162 let denominator_sumcheck_eval = round_proof.denominator_0 * round_proof.denominator_1;
163 let expected_final_eval =
164 eq_eval * (numerator_sumcheck_eval * lambda + denominator_sumcheck_eval);
165 if final_eval != expected_final_eval {
166 return Err(LogupGkrVerificationError::InconsistentEvaluation(i));
167 }
168
169 challenger.observe_ext_element(round_proof.numerator_0);
171 challenger.observe_ext_element(round_proof.numerator_1);
172 challenger.observe_ext_element(round_proof.denominator_0);
173 challenger.observe_ext_element(round_proof.denominator_1);
174
175 eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
177 let last_coordinate = challenger.sample_ext_element::<EF>();
179 eval_point.add_dimension_back(last_coordinate);
180 numerator_eval = round_proof.numerator_0
182 + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
183 denominator_eval = round_proof.denominator_0
184 + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
185 }
186
187 let (interaction_point, trace_point) =
189 eval_point.split_at(number_of_interaction_variables as usize);
190 let trace_variables = trace_point.dimension();
192 if trace_variables != max_log_row_count {
193 return Err(LogupGkrVerificationError::InvalidLastLayerDimension(
194 trace_variables,
195 max_log_row_count,
196 ));
197 }
198
199 let LogUpEvaluations { point, chip_openings } = logup_evaluations;
201 if point != &trace_point {
202 return Err(LogupGkrVerificationError::TracePointMismatch);
203 }
204
205 let betas = partial_lagrange_blocking(beta_seed);
206
207 let mut numerator_values = Vec::with_capacity(num_of_interactions);
210 let mut denominator_values = Vec::with_capacity(num_of_interactions);
211 let mut point_extended = point.clone();
212 point_extended.add_dimension(EF::zero());
213 let len = shard_chips.len();
214 challenger.observe(F::from_canonical_usize(len));
215 for ((chip, openings), threshold) in
216 shard_chips.iter().zip_eq(chip_openings.values()).zip_eq(degrees.values())
217 {
218 if let Some(prep_eval) = openings.preprocessed_trace_evaluations.as_ref() {
220 challenger.observe_variable_length_extension_slice(prep_eval);
221 if prep_eval.evaluations().sizes() != [chip.air.preprocessed_width()] {
222 return Err(LogupGkrVerificationError::InvalidShape);
223 }
224 } else if chip.air.preprocessed_width() != 0 {
225 return Err(LogupGkrVerificationError::InvalidShape);
226 }
227 challenger.observe_variable_length_extension_slice(&openings.main_trace_evaluations);
228 if openings.main_trace_evaluations.evaluations().sizes() != [chip.air.width()] {
229 return Err(LogupGkrVerificationError::InvalidShape);
230 }
231
232 if threshold.dimension() != point_extended.dimension() {
233 return Err(LogupGkrVerificationError::InvalidShape);
234 }
235
236 let geq_eval = full_geq(threshold, &point_extended);
237 let ChipEvaluation { main_trace_evaluations, preprocessed_trace_evaluations } =
238 openings;
239 for (interaction, is_send) in chip
240 .sends()
241 .iter()
242 .map(|s| (s, true))
243 .chain(chip.receives().iter().map(|r| (r, false)))
244 {
245 let (real_numerator, real_denominator) = interaction.eval(
246 preprocessed_trace_evaluations.as_ref(),
247 main_trace_evaluations,
248 alpha,
249 betas.as_slice(),
250 );
251 let padding_trace_opening =
252 MleEval::from(vec![EF::zero(); main_trace_evaluations.num_polynomials()]);
253 let padding_preprocessed_opening = preprocessed_trace_evaluations
254 .as_ref()
255 .map(|eval| MleEval::from(vec![EF::zero(); eval.num_polynomials()]));
256 let (padding_numerator, padding_denominator) = interaction.eval(
257 padding_preprocessed_opening.as_ref(),
258 &padding_trace_opening,
259 alpha,
260 betas.as_slice(),
261 );
262
263 let numerator_eval = real_numerator - padding_numerator * geq_eval;
264 let denominator_eval =
265 real_denominator + (EF::one() - padding_denominator) * geq_eval;
266 let numerator_eval = if is_send { numerator_eval } else { -numerator_eval };
267 numerator_values.push(numerator_eval);
268 denominator_values.push(denominator_eval);
269 }
270 }
271 numerator_values.resize(1 << interaction_point.dimension(), EF::zero());
274 let numerator = Mle::from(numerator_values);
275 denominator_values.resize(1 << interaction_point.dimension(), EF::one());
277 let denominator = Mle::from(denominator_values);
278
279 let expected_numerator_eval = numerator.blocking_eval_at(&interaction_point)[0];
280 let expected_denominator_eval = denominator.blocking_eval_at(&interaction_point)[0];
281 if numerator_eval != expected_numerator_eval {
282 return Err(LogupGkrVerificationError::NumeratorEvaluationMismatch(
283 numerator_eval,
284 expected_numerator_eval,
285 ));
286 }
287 if denominator_eval != expected_denominator_eval {
288 return Err(LogupGkrVerificationError::DenominatorEvaluationMismatch(
289 denominator_eval,
290 expected_denominator_eval,
291 ));
292 }
293 Ok(())
294 }
295}