1use crate::prover::Record;
2use crate::record::MachineRecord;
3use crate::VerifierPublicValuesConstraintFolder;
4use crate::GKR_GRINDING_BITS;
5use crate::{air::MachineAir, Chip, ShardContext};
6use itertools::Itertools;
7use slop_air::BaseAir;
8use slop_algebra::AbstractField;
9use slop_challenger::GrindingChallenger;
10use slop_challenger::{CanObserve, FieldChallenger, IopCtx, VariableLengthChallenger};
11use slop_multilinear::{
12 full_geq, partial_lagrange_blocking, Mle, MleEval, MultilinearPcsChallenger, Point,
13};
14use slop_sumcheck::{partially_verify_sumcheck_proof, SumcheckError};
15use std::cmp::max;
16use std::{
17 collections::{BTreeMap, BTreeSet},
18 marker::PhantomData,
19};
20use thiserror::Error;
21
22use super::{ChipEvaluation, LogUpEvaluations, LogUpGkrOutput, LogupGkrProof};
23
/// Errors that can be reported while verifying a LogUp-GKR proof.
///
/// `EF` is the extension-field element type carried by the mismatch variants so the offending
/// values can be reported.
#[derive(Debug, Error)]
pub enum LogupGkrVerificationError<EF> {
    /// A round's sumcheck `claimed_sum` differs from the lambda-batched claim
    /// (`numerator_eval * lambda + denominator_eval`) carried over from the previous layer.
    /// Holds the round index.
    #[error("inconsistent sumcheck claim at round {0}")]
    InconsistentSumcheckClaim(usize),
    /// A round's sumcheck final evaluation does not match `eq(point, eval_point)` times the
    /// lambda-batched combination of the round's claimed numerator/denominator values.
    /// Holds the round index.
    #[error("inconsistent evaluation at round {0}")]
    InconsistentEvaluation(usize),
    /// An error bubbled up from partial sumcheck verification (converted via `#[from]`).
    #[error("sumcheck error: {0}")]
    SumcheckError(#[from] SumcheckError),
    /// Some proof component has an unexpected size, count, or dimension.
    #[error("invalid shape")]
    InvalidShape,
    /// The circuit output has the wrong number of variables: `(found, expected)`.
    #[error("invalid first layer dimension: {0} != {1}")]
    InvalidFirstLayerDimension(u32, u32),
    /// The trace part of the final GKR point has the wrong dimension: `(found, expected)`.
    #[error("invalid last layer dimension: {0} != {1}")]
    InvalidLastLayerDimension(usize, usize),
    /// The point attached to the chip openings differs from the trace part of the final
    /// GKR evaluation point.
    #[error("trace point mismatch")]
    TracePointMismatch,
    /// The sum of the circuit-output fractions differs from the value expected from the
    /// public values: `(output sum, expected sum)`.
    #[error("cumulative sum mismatch: {0} != {1}")]
    CumulativeSumMismatch(EF, EF),
    /// The GKR-reduced numerator claim differs from the value recomputed from the chip
    /// openings: `(reduced claim, recomputed value)`.
    #[error("numerator evaluation mismatch: {0} != {1}")]
    NumeratorEvaluationMismatch(EF, EF),
    /// The GKR-reduced denominator claim differs from the value recomputed from the chip
    /// openings: `(reduced claim, recomputed value)`.
    #[error("denominator evaluation mismatch: {0} != {1}")]
    DenominatorEvaluationMismatch(EF, EF),
    /// At least one entry of the circuit-output denominator is zero.
    #[error("denominator evaluation has zero value")]
    ZeroDenominator,
    /// The proof-of-work (grinding) witness failed the challenger's check.
    #[error("Invalid proof of work witness")]
    Pow,
    /// The folded public-values constraints did not vanish.
    #[error("public values verification failed")]
    InvalidPublicValues,
}
67
/// Verifier for LogUp-GKR proofs.
///
/// The type holds no data: `GC` (the IOP context) and `SC` (the shard context) are tracked only
/// through `PhantomData`, and all functionality is exposed through associated functions.
#[derive(Clone, Debug, Copy, Default, PartialEq, Eq, Hash)]
pub struct LogUpGkrVerifier<GC, SC>(PhantomData<(GC, SC)>);
71
72impl<GC: IopCtx, SC: ShardContext<GC>> LogUpGkrVerifier<GC, SC> {
73 pub fn verify_public_values(
75 challenge: GC::EF,
76 alpha: &GC::EF,
77 beta_seed: &Point<GC::EF>,
78 public_values: &[GC::F],
79 ) -> Result<GC::EF, LogupGkrVerificationError<GC::EF>> {
80 let betas = slop_multilinear::partial_lagrange_blocking(beta_seed).into_buffer().into_vec();
81 let mut folder = VerifierPublicValuesConstraintFolder::<GC> {
82 perm_challenges: (alpha, &betas),
83 alpha: challenge,
84 accumulator: GC::EF::zero(),
85 local_interaction_digest: GC::EF::zero(),
86 public_values,
87 _marker: PhantomData,
88 };
89 Record::<_, SC>::eval_public_values(&mut folder);
90 if folder.accumulator == GC::EF::zero() {
91 Ok(folder.local_interaction_digest)
92 } else {
93 Err(LogupGkrVerificationError::InvalidPublicValues)
94 }
95 }
96
97 #[allow(clippy::too_many_arguments)]
101 #[allow(clippy::too_many_lines)]
102 pub fn verify_logup_gkr(
103 shard_chips: &BTreeSet<Chip<GC::F, SC::Air>>,
104 degrees: &BTreeMap<String, Point<GC::F>>,
105 max_log_row_count: usize,
106 proof: &LogupGkrProof<<GC::Challenger as GrindingChallenger>::Witness, GC::EF>,
107 public_values: &[GC::F],
108 challenger: &mut GC::Challenger,
109 ) -> Result<(), LogupGkrVerificationError<GC::EF>> {
110 let LogupGkrProof { circuit_output, round_proofs, logup_evaluations, witness } = proof;
111
112 let LogUpGkrOutput { numerator, denominator } = circuit_output;
113 let max_interaction_arity = shard_chips
114 .iter()
115 .flat_map(|c| c.sends().iter().chain(c.receives().iter()))
116 .map(|i| i.values.len() + 1)
117 .max()
118 .unwrap();
119
120 let max_interaction_kinds_values = Record::<_, SC>::interactions_in_public_values()
121 .iter()
122 .map(|kind| kind.num_values() + 1)
123 .max()
124 .unwrap_or(1);
125 let beta_seed_dim =
126 max(max_interaction_arity, max_interaction_kinds_values).next_power_of_two().ilog2();
127
128 if !challenger.check_witness(GKR_GRINDING_BITS, *witness) {
131 return Err(LogupGkrVerificationError::Pow);
132 }
133
134 let alpha = challenger.sample_ext_element::<GC::EF>();
135 let beta_seed = (0..beta_seed_dim)
136 .map(|_| challenger.sample_ext_element::<GC::EF>())
137 .collect::<Point<_>>();
138 let pv_challenge = challenger.sample_ext_element::<GC::EF>();
139 let cumulative_sum = -LogUpGkrVerifier::<GC, SC>::verify_public_values(
140 pv_challenge,
141 &alpha,
142 &beta_seed,
143 public_values,
144 )?;
145
146 let num_of_interactions =
148 shard_chips.iter().map(|c| c.sends().len() + c.receives().len()).sum::<usize>();
149 let number_of_interaction_variables = num_of_interactions.next_power_of_two().ilog2();
150
151 let expected_size = 1 << (number_of_interaction_variables + 1);
152
153 if numerator.guts().dimensions.sizes() != [expected_size, 1]
154 || denominator.guts().dimensions.sizes() != [expected_size, 1]
155 {
156 return Err(LogupGkrVerificationError::InvalidShape);
157 }
158
159 challenger.observe_variable_length_extension_slice(numerator.guts().as_slice());
161 challenger.observe_variable_length_extension_slice(denominator.guts().as_slice());
162
163 if denominator.guts().as_slice().iter().any(slop_algebra::Field::is_zero) {
164 return Err(LogupGkrVerificationError::ZeroDenominator);
165 }
166
167 let output_cumulative_sum = numerator
169 .guts()
170 .as_slice()
171 .iter()
172 .zip_eq(denominator.guts().as_slice().iter())
173 .map(|(n, d)| *n / *d)
174 .sum::<GC::EF>();
175 if output_cumulative_sum != cumulative_sum {
176 return Err(LogupGkrVerificationError::CumulativeSumMismatch(
177 output_cumulative_sum,
178 cumulative_sum,
179 ));
180 }
181
182 let initial_number_of_variables = numerator.num_variables();
184 if initial_number_of_variables != number_of_interaction_variables + 1 {
185 return Err(LogupGkrVerificationError::InvalidFirstLayerDimension(
186 initial_number_of_variables,
187 number_of_interaction_variables + 1,
188 ));
189 }
190 let first_eval_point = challenger.sample_point::<GC::EF>(initial_number_of_variables);
192
193 let mut numerator_eval = numerator.blocking_eval_at(&first_eval_point)[0];
195 let mut denominator_eval = denominator.blocking_eval_at(&first_eval_point)[0];
196 let mut eval_point = first_eval_point;
197
198 if round_proofs.len() + 1 != max_log_row_count {
199 return Err(LogupGkrVerificationError::InvalidShape);
200 }
201
202 for (i, round_proof) in round_proofs.iter().enumerate() {
203 let lambda = challenger.sample_ext_element::<GC::EF>();
205 let expected_claim = numerator_eval * lambda + denominator_eval;
207 if round_proof.sumcheck_proof.claimed_sum != expected_claim {
208 return Err(LogupGkrVerificationError::InconsistentSumcheckClaim(i));
209 }
210 partially_verify_sumcheck_proof(
212 &round_proof.sumcheck_proof,
213 challenger,
214 i + number_of_interaction_variables as usize + 1,
215 3,
216 )?;
217 let (point, final_eval) = round_proof.sumcheck_proof.point_and_eval.clone();
219 let eq_eval = Mle::full_lagrange_eval(&point, &eval_point);
220 let numerator_sumcheck_eval = round_proof.numerator_0 * round_proof.denominator_1
221 + round_proof.numerator_1 * round_proof.denominator_0;
222 let denominator_sumcheck_eval = round_proof.denominator_0 * round_proof.denominator_1;
223 let expected_final_eval =
224 eq_eval * (numerator_sumcheck_eval * lambda + denominator_sumcheck_eval);
225 if final_eval != expected_final_eval {
226 return Err(LogupGkrVerificationError::InconsistentEvaluation(i));
227 }
228
229 challenger.observe_ext_element(round_proof.numerator_0);
231 challenger.observe_ext_element(round_proof.numerator_1);
232 challenger.observe_ext_element(round_proof.denominator_0);
233 challenger.observe_ext_element(round_proof.denominator_1);
234
235 eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
237 let last_coordinate = challenger.sample_ext_element::<GC::EF>();
239 eval_point.add_dimension_back(last_coordinate);
240 numerator_eval = round_proof.numerator_0
242 + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
243 denominator_eval = round_proof.denominator_0
244 + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
245 }
246
247 let (interaction_point, trace_point) =
249 eval_point.split_at(number_of_interaction_variables as usize);
250 let trace_variables = trace_point.dimension();
252 if trace_variables != max_log_row_count {
253 return Err(LogupGkrVerificationError::InvalidLastLayerDimension(
254 trace_variables,
255 max_log_row_count,
256 ));
257 }
258
259 let LogUpEvaluations { point, chip_openings } = logup_evaluations;
261 if point != &trace_point {
262 return Err(LogupGkrVerificationError::TracePointMismatch);
263 }
264
265 let betas = partial_lagrange_blocking(&beta_seed);
266
267 let mut numerator_values = Vec::with_capacity(num_of_interactions);
270 let mut denominator_values = Vec::with_capacity(num_of_interactions);
271 let mut point_extended = point.clone();
272 point_extended.add_dimension(GC::EF::zero());
273 let len = shard_chips.len();
274 challenger.observe(GC::F::from_canonical_usize(len));
275 for ((chip, openings), threshold) in
276 shard_chips.iter().zip_eq(chip_openings.values()).zip_eq(degrees.values())
277 {
278 if let Some(prep_eval) = openings.preprocessed_trace_evaluations.as_ref() {
280 challenger.observe_variable_length_extension_slice(prep_eval);
281 if prep_eval.evaluations().sizes() != [chip.air.preprocessed_width()] {
282 return Err(LogupGkrVerificationError::InvalidShape);
283 }
284 } else if chip.air.preprocessed_width() != 0 {
285 return Err(LogupGkrVerificationError::InvalidShape);
286 }
287 challenger.observe_variable_length_extension_slice(&openings.main_trace_evaluations);
288 if openings.main_trace_evaluations.evaluations().sizes() != [chip.air.width()] {
289 return Err(LogupGkrVerificationError::InvalidShape);
290 }
291
292 if threshold.dimension() != point_extended.dimension() {
293 return Err(LogupGkrVerificationError::InvalidShape);
294 }
295
296 let geq_eval = full_geq(threshold, &point_extended);
297 let ChipEvaluation { main_trace_evaluations, preprocessed_trace_evaluations } =
298 openings;
299 for (interaction, is_send) in chip
300 .sends()
301 .iter()
302 .map(|s| (s, true))
303 .chain(chip.receives().iter().map(|r| (r, false)))
304 {
305 let (real_numerator, real_denominator) = interaction.eval(
306 preprocessed_trace_evaluations.as_ref(),
307 main_trace_evaluations,
308 alpha,
309 betas.as_slice(),
310 );
311 let padding_trace_opening =
312 MleEval::from(vec![GC::EF::zero(); main_trace_evaluations.num_polynomials()]);
313 let padding_preprocessed_opening = preprocessed_trace_evaluations
314 .as_ref()
315 .map(|eval| MleEval::from(vec![GC::EF::zero(); eval.num_polynomials()]));
316 let (padding_numerator, padding_denominator) = interaction.eval(
317 padding_preprocessed_opening.as_ref(),
318 &padding_trace_opening,
319 alpha,
320 betas.as_slice(),
321 );
322
323 let numerator_eval = real_numerator - padding_numerator * geq_eval;
324 let denominator_eval =
325 real_denominator + (GC::EF::one() - padding_denominator) * geq_eval;
326 let numerator_eval = if is_send { numerator_eval } else { -numerator_eval };
327 numerator_values.push(numerator_eval);
328 denominator_values.push(denominator_eval);
329 }
330 }
331 numerator_values.resize(1 << interaction_point.dimension(), GC::EF::zero());
334 let numerator = Mle::from(numerator_values);
335 denominator_values.resize(1 << interaction_point.dimension(), GC::EF::one());
337 let denominator = Mle::from(denominator_values);
338
339 let expected_numerator_eval = numerator.blocking_eval_at(&interaction_point)[0];
340 let expected_denominator_eval = denominator.blocking_eval_at(&interaction_point)[0];
341 if numerator_eval != expected_numerator_eval {
342 return Err(LogupGkrVerificationError::NumeratorEvaluationMismatch(
343 numerator_eval,
344 expected_numerator_eval,
345 ));
346 }
347 if denominator_eval != expected_denominator_eval {
348 return Err(LogupGkrVerificationError::DenominatorEvaluationMismatch(
349 denominator_eval,
350 expected_denominator_eval,
351 ));
352 }
353 Ok(())
354 }
355}