Skip to main content

sp1_hypercube/logup_gkr/
verifier.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    marker::PhantomData,
4};
5
6use itertools::Itertools;
7use slop_algebra::{ExtensionField, Field};
8use slop_challenger::{FieldChallenger, VariableLengthChallenger};
9use slop_multilinear::{
10    full_geq, partial_lagrange_blocking, Mle, MleEval, MultilinearPcsChallenger, Point,
11};
12use slop_sumcheck::{partially_verify_sumcheck_proof, SumcheckError};
13use thiserror::Error;
14
15use crate::{air::MachineAir, Chip};
16
17use super::{ChipEvaluation, LogUpEvaluations, LogUpGkrOutput, LogupGkrProof};
18
19/// An error type for `LogUp` GKR.
20#[derive(Debug, Error)]
21pub enum LogupGkrVerificationError<EF> {
22    /// The sumcheck claim is not consistent with the calculated one from the prover messages.
23    #[error("inconsistent sumcheck claim at round {0}")]
24    InconsistentSumcheckClaim(usize),
25    /// Inconsistency between the calculated evaluation and the sumcheck evaluation.
26    #[error("inconsistent evaluation at round {0}")]
27    InconsistentEvaluation(usize),
28    /// Error when verifying sumcheck proof.
29    #[error("sumcheck error: {0}")]
30    SumcheckError(#[from] SumcheckError),
31    /// The proof shape does not match the expected one for the given number of interactions.
32    #[error("invalid shape")]
33    InvalidShape,
34    /// The size of the first layer does not match the expected one.
35    #[error("invalid first layer dimension: {0} != {1}")]
36    InvalidFirstLayerDimension(u32, u32),
37    /// The dimension of the last layer does not match the expected one.
38    #[error("invalid last layer dimension: {0} != {1}")]
39    InvalidLastLayerDimension(usize, usize),
40    /// The trace point does not match the claimed opening point.
41    #[error("trace point mismatch")]
42    TracePointMismatch,
43    /// The cumulative sum does not match the claimed one.
44    #[error("cumulative sum mismatch: {0} != {1}")]
45    CumulativeSumMismatch(EF, EF),
46    /// The numerator evaluation does not match the expected one.
47    #[error("numerator evaluation mismatch: {0} != {1}")]
48    NumeratorEvaluationMismatch(EF, EF),
49    /// The denominator evaluation does not match the expected one.
50    #[error("denominator evaluation mismatch: {0} != {1}")]
51    DenominatorEvaluationMismatch(EF, EF),
52    /// The denominator guts had zero in it.
53    #[error("denominator evaluation has zero value")]
54    ZeroDenominator,
55}
56
/// Stateless `LogUp` GKR verifier.
///
/// The type carries no data; its parameters only select the base field `F`, the extension
/// field `EF`, and the AIR type `A` used by its associated verification function.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
pub struct LogUpGkrVerifier<F, EF, A>(PhantomData<(F, EF, A)>);
60
61impl<F, EF, A> LogUpGkrVerifier<F, EF, A>
62where
63    F: Field,
64    EF: ExtensionField<F>,
65    A: MachineAir<F>,
66{
67    /// Verify the `LogUp` GKR proof.
68    ///
69    /// # Errors
70    #[allow(clippy::too_many_arguments)]
71    #[allow(clippy::too_many_lines)]
72    pub fn verify_logup_gkr(
73        shard_chips: &BTreeSet<Chip<F, A>>,
74        degrees: &BTreeMap<String, Point<F>>,
75        alpha: EF,
76        beta_seed: &Point<EF>,
77        cumulative_sum: EF,
78        max_log_row_count: usize,
79        proof: &LogupGkrProof<EF>,
80        challenger: &mut impl FieldChallenger<F>,
81    ) -> Result<(), LogupGkrVerificationError<EF>> {
82        let LogupGkrProof { circuit_output, round_proofs, logup_evaluations } = proof;
83
84        let LogUpGkrOutput { numerator, denominator } = circuit_output;
85
86        // Calculate the interaction number.
87        let num_of_interactions =
88            shard_chips.iter().map(|c| c.sends().len() + c.receives().len()).sum::<usize>();
89        let number_of_interaction_variables = num_of_interactions.next_power_of_two().ilog2();
90
91        let expected_size = 1 << (number_of_interaction_variables + 1);
92
93        if numerator.guts().dimensions.sizes() != [expected_size, 1]
94            || denominator.guts().dimensions.sizes() != [expected_size, 1]
95        {
96            return Err(LogupGkrVerificationError::InvalidShape);
97        }
98
99        // Observe the output claims.
100        challenger.observe_variable_length_extension_slice(numerator.guts().as_slice());
101        challenger.observe_variable_length_extension_slice(denominator.guts().as_slice());
102
103        if denominator.guts().as_slice().iter().any(slop_algebra::Field::is_zero) {
104            return Err(LogupGkrVerificationError::ZeroDenominator);
105        }
106
107        // Verify that the cumulative sum matches the claimed one.
108        let output_cumulative_sum = numerator
109            .guts()
110            .as_slice()
111            .iter()
112            .zip_eq(denominator.guts().as_slice().iter())
113            .map(|(n, d)| *n / *d)
114            .sum::<EF>();
115        if output_cumulative_sum != cumulative_sum {
116            return Err(LogupGkrVerificationError::CumulativeSumMismatch(
117                output_cumulative_sum,
118                cumulative_sum,
119            ));
120        }
121
122        // Assert that the size of the first layer matches the expected one.
123        let initial_number_of_variables = numerator.num_variables();
124        if initial_number_of_variables != number_of_interaction_variables + 1 {
125            return Err(LogupGkrVerificationError::InvalidFirstLayerDimension(
126                initial_number_of_variables,
127                number_of_interaction_variables + 1,
128            ));
129        }
130        // Sample the first evaluation point.
131        let first_eval_point = challenger.sample_point::<EF>(initial_number_of_variables);
132
133        // Follow the GKR protocol layer by layer.
134        let mut numerator_eval = numerator.blocking_eval_at(&first_eval_point)[0];
135        let mut denominator_eval = denominator.blocking_eval_at(&first_eval_point)[0];
136        let mut eval_point = first_eval_point;
137
138        if round_proofs.len() + 1 != max_log_row_count {
139            return Err(LogupGkrVerificationError::InvalidShape);
140        }
141
142        for (i, round_proof) in round_proofs.iter().enumerate() {
143            // Get the batching challenge for combining the claims.
144            let lambda = challenger.sample_ext_element::<EF>();
145            // Check that the claimed sum is consistent with the previous round values.
146            let expected_claim = numerator_eval * lambda + denominator_eval;
147            if round_proof.sumcheck_proof.claimed_sum != expected_claim {
148                return Err(LogupGkrVerificationError::InconsistentSumcheckClaim(i));
149            }
150            // Verify the sumcheck proof.
151            partially_verify_sumcheck_proof(
152                &round_proof.sumcheck_proof,
153                challenger,
154                i + number_of_interaction_variables as usize + 1,
155                3,
156            )?;
157            // Verify that the evaluation claim is consistent with the prover messages.
158            let (point, final_eval) = round_proof.sumcheck_proof.point_and_eval.clone();
159            let eq_eval = Mle::full_lagrange_eval(&point, &eval_point);
160            let numerator_sumcheck_eval = round_proof.numerator_0 * round_proof.denominator_1
161                + round_proof.numerator_1 * round_proof.denominator_0;
162            let denominator_sumcheck_eval = round_proof.denominator_0 * round_proof.denominator_1;
163            let expected_final_eval =
164                eq_eval * (numerator_sumcheck_eval * lambda + denominator_sumcheck_eval);
165            if final_eval != expected_final_eval {
166                return Err(LogupGkrVerificationError::InconsistentEvaluation(i));
167            }
168
169            // Observe the prover message.
170            challenger.observe_ext_element(round_proof.numerator_0);
171            challenger.observe_ext_element(round_proof.numerator_1);
172            challenger.observe_ext_element(round_proof.denominator_0);
173            challenger.observe_ext_element(round_proof.denominator_1);
174
175            // Get the evaluation point for the claims of the next round.
176            eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
177            // Sample the last coordinate and add to the point.
178            let last_coordinate = challenger.sample_ext_element::<EF>();
179            eval_point.add_dimension_back(last_coordinate);
180            // Update the evaluation of the numerator and denominator at the last coordinate.
181            numerator_eval = round_proof.numerator_0
182                + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
183            denominator_eval = round_proof.denominator_0
184                + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
185        }
186
187        // Verify that the last layer evaluations are consistent with the evaluations of the traces.
188        let (interaction_point, trace_point) =
189            eval_point.split_at(number_of_interaction_variables as usize);
190        // Assert that the number of trace variables matches the expected one.
191        let trace_variables = trace_point.dimension();
192        if trace_variables != max_log_row_count {
193            return Err(LogupGkrVerificationError::InvalidLastLayerDimension(
194                trace_variables,
195                max_log_row_count,
196            ));
197        }
198
199        // Assert that the trace point is the same as the claimed opening point
200        let LogUpEvaluations { point, chip_openings } = logup_evaluations;
201        if point != &trace_point {
202            return Err(LogupGkrVerificationError::TracePointMismatch);
203        }
204
205        let betas = partial_lagrange_blocking(beta_seed);
206
207        // Compute the expected opening of the last layer numerator and denominator values from the
208        // trace openings.
209        let mut numerator_values = Vec::with_capacity(num_of_interactions);
210        let mut denominator_values = Vec::with_capacity(num_of_interactions);
211        let mut point_extended = point.clone();
212        point_extended.add_dimension(EF::zero());
213        let len = shard_chips.len();
214        challenger.observe(F::from_canonical_usize(len));
215        for ((chip, openings), threshold) in
216            shard_chips.iter().zip_eq(chip_openings.values()).zip_eq(degrees.values())
217        {
218            // Observe the opening
219            if let Some(prep_eval) = openings.preprocessed_trace_evaluations.as_ref() {
220                challenger.observe_variable_length_extension_slice(prep_eval);
221                if prep_eval.evaluations().sizes() != [chip.air.preprocessed_width()] {
222                    return Err(LogupGkrVerificationError::InvalidShape);
223                }
224            } else if chip.air.preprocessed_width() != 0 {
225                return Err(LogupGkrVerificationError::InvalidShape);
226            }
227            challenger.observe_variable_length_extension_slice(&openings.main_trace_evaluations);
228            if openings.main_trace_evaluations.evaluations().sizes() != [chip.air.width()] {
229                return Err(LogupGkrVerificationError::InvalidShape);
230            }
231
232            if threshold.dimension() != point_extended.dimension() {
233                return Err(LogupGkrVerificationError::InvalidShape);
234            }
235
236            let geq_eval = full_geq(threshold, &point_extended);
237            let ChipEvaluation { main_trace_evaluations, preprocessed_trace_evaluations } =
238                openings;
239            for (interaction, is_send) in chip
240                .sends()
241                .iter()
242                .map(|s| (s, true))
243                .chain(chip.receives().iter().map(|r| (r, false)))
244            {
245                let (real_numerator, real_denominator) = interaction.eval(
246                    preprocessed_trace_evaluations.as_ref(),
247                    main_trace_evaluations,
248                    alpha,
249                    betas.as_slice(),
250                );
251                let padding_trace_opening =
252                    MleEval::from(vec![EF::zero(); main_trace_evaluations.num_polynomials()]);
253                let padding_preprocessed_opening = preprocessed_trace_evaluations
254                    .as_ref()
255                    .map(|eval| MleEval::from(vec![EF::zero(); eval.num_polynomials()]));
256                let (padding_numerator, padding_denominator) = interaction.eval(
257                    padding_preprocessed_opening.as_ref(),
258                    &padding_trace_opening,
259                    alpha,
260                    betas.as_slice(),
261                );
262
263                let numerator_eval = real_numerator - padding_numerator * geq_eval;
264                let denominator_eval =
265                    real_denominator + (EF::one() - padding_denominator) * geq_eval;
266                let numerator_eval = if is_send { numerator_eval } else { -numerator_eval };
267                numerator_values.push(numerator_eval);
268                denominator_values.push(denominator_eval);
269            }
270        }
271        // Convert the values to a multilinear polynomials.
272        // Pad the numerator values with zeros.
273        numerator_values.resize(1 << interaction_point.dimension(), EF::zero());
274        let numerator = Mle::from(numerator_values);
275        // Pad the denominator values with ones.
276        denominator_values.resize(1 << interaction_point.dimension(), EF::one());
277        let denominator = Mle::from(denominator_values);
278
279        let expected_numerator_eval = numerator.blocking_eval_at(&interaction_point)[0];
280        let expected_denominator_eval = denominator.blocking_eval_at(&interaction_point)[0];
281        if numerator_eval != expected_numerator_eval {
282            return Err(LogupGkrVerificationError::NumeratorEvaluationMismatch(
283                numerator_eval,
284                expected_numerator_eval,
285            ));
286        }
287        if denominator_eval != expected_denominator_eval {
288            return Err(LogupGkrVerificationError::DenominatorEvaluationMismatch(
289                denominator_eval,
290                expected_denominator_eval,
291            ));
292        }
293        Ok(())
294    }
295}