// sp1_hypercube/logup_gkr/prover.rs

use std::collections::{BTreeMap, BTreeSet};

use slop_algebra::AbstractField;
use slop_alloc::{CanCopyFromRef, CpuBackend, ToHost};
use slop_challenger::{
    CanObserve, FieldChallenger, GrindingChallenger, IopCtx, VariableLengthChallenger,
};
use slop_multilinear::{Mle, MultilinearPcsChallenger, Point};

use crate::{
    air::MachineAir, prove_gkr_round, prover::Traces, Chip, ChipEvaluation, LogupGkrCpuCircuit,
    LogupGkrCpuTraceGenerator, ShardContext, GKR_GRINDING_BITS,
};

use super::{LogUpEvaluations, LogUpGkrOutput, LogupGkrProof, LogupGkrRoundProof};

17/// TODO
18pub struct GkrProverImpl<GC: IopCtx, SC: ShardContext<GC>> {
19    /// TODO
20    trace_generator: LogupGkrCpuTraceGenerator<GC::F, GC::EF, SC::Air>,
21}
22
23/// TODO
24impl<GC: IopCtx, SC: ShardContext<GC>> GkrProverImpl<GC, SC> {
25    /// TODO
26    #[must_use]
27    pub fn new(trace_generator: LogupGkrCpuTraceGenerator<GC::F, GC::EF, SC::Air>) -> Self {
28        Self { trace_generator }
29    }
30
31    /// TODO
32    pub fn prove_gkr_circuit(
33        &self,
34        numerator_value: GC::EF,
35        denominator_value: GC::EF,
36        eval_point: Point<GC::EF>,
37        mut circuit: LogupGkrCpuCircuit<GC::F, GC::EF>,
38        challenger: &mut GC::Challenger,
39    ) -> (Point<GC::EF>, Vec<LogupGkrRoundProof<GC::EF>>) {
40        let mut round_proofs = Vec::new();
41        // Follow the GKR protocol layer by layer.
42        let mut numerator_eval = numerator_value;
43        let mut denominator_eval = denominator_value;
44        let mut eval_point = eval_point;
45        while let Some(layer) = circuit.next_layer() {
46            let round_proof =
47                prove_gkr_round(layer, &eval_point, numerator_eval, denominator_eval, challenger);
48            // Observe the prover message.
49            challenger.observe_ext_element(round_proof.numerator_0);
50            challenger.observe_ext_element(round_proof.numerator_1);
51            challenger.observe_ext_element(round_proof.denominator_0);
52            challenger.observe_ext_element(round_proof.denominator_1);
53            // Get the evaluation point for the claims of the next round.
54            eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
55            // Sample the last coordinate.
56            let last_coordinate = challenger.sample_ext_element::<GC::EF>();
57            // Compute the evaluation of the numerator and denominator at the last coordinate.
58            numerator_eval = round_proof.numerator_0
59                + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
60            denominator_eval = round_proof.denominator_0
61                + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
62            eval_point.add_dimension_back(last_coordinate);
63            // Add the round proof to the total
64            round_proofs.push(round_proof);
65        }
66        (eval_point, round_proofs)
67    }
68
69    #[allow(clippy::too_many_arguments)]
70    pub(crate) fn prove_logup_gkr(
71        &self,
72        chips: &BTreeSet<Chip<GC::F, SC::Air>>,
73        preprocessed_traces: &Traces<GC::F, CpuBackend>,
74        traces: &Traces<GC::F, CpuBackend>,
75        public_values: Vec<GC::F>,
76        challenger: &mut GC::Challenger,
77    ) -> LogupGkrProof<<GC::Challenger as GrindingChallenger>::Witness, GC::EF> {
78        let max_interaction_arity = chips
79            .iter()
80            .flat_map(|c| c.sends().iter().chain(c.receives().iter()))
81            .map(|i| i.values.len() + 1)
82            .max()
83            .unwrap();
84        let beta_seed_dim = max_interaction_arity.next_power_of_two().ilog2();
85
86        let witness = challenger.grind(GKR_GRINDING_BITS);
87
88        // Sample the logup challenges.
89        let alpha = challenger.sample_ext_element::<GC::EF>();
90        let beta_seed = (0..beta_seed_dim)
91            .map(|_| challenger.sample_ext_element::<GC::EF>())
92            .collect::<Point<_>>();
93        let _pv_challenge = challenger.sample_ext_element::<GC::EF>();
94
95        let num_interactions =
96            chips.iter().map(|chip| chip.sends().len() + chip.receives().len()).sum::<usize>();
97        let num_interaction_variables = num_interactions.next_power_of_two().ilog2();
98
99        #[cfg(sp1_debug_constraints)]
100        {
101            use crate::{
102                air::InteractionScope, debug_interactions_with_all_chips, InteractionKind,
103            };
104            use slop_alloc::CanCopyIntoRef;
105
106            let mut host_preprocessed_traces = BTreeMap::new();
107
108            for (name, preprocessed_trace) in preprocessed_traces.iter() {
109                let host_preprocessed_trace =
110                    CpuBackend::copy_to_dst(&CpuBackend, preprocessed_trace).unwrap();
111                host_preprocessed_traces.insert(name.clone(), host_preprocessed_trace);
112            }
113
114            let mut host_traces = BTreeMap::new();
115            for (name, trace) in traces.iter() {
116                let host_trace = CpuBackend::copy_to_dst(&CpuBackend, trace).unwrap();
117                host_traces.insert(name.clone(), host_trace);
118            }
119
120            let host_traces = Traces { named_traces: host_traces };
121
122            let host_preprocessed_traces = Traces { named_traces: host_preprocessed_traces };
123
124            debug_interactions_with_all_chips::<GC::F, SC::Air>(
125                &chips.iter().cloned().collect::<Vec<_>>(),
126                &host_preprocessed_traces,
127                &host_traces,
128                public_values.clone(),
129                InteractionKind::all_kinds(),
130                InteractionScope::Local,
131            );
132        }
133
134        // Run the GKR circuit and get the output.
135        let (output, circuit) = {
136            let _span = tracing::debug_span!("generate GKR circuit").entered();
137            self.trace_generator.generate_gkr_circuit(
138                chips,
139                preprocessed_traces.clone(),
140                traces.clone(),
141                public_values,
142                alpha,
143                beta_seed,
144            )
145        };
146
147        let LogUpGkrOutput { numerator, denominator } = &output;
148
149        let host_numerator = numerator.to_host().unwrap();
150        let host_denominator = denominator.to_host().unwrap();
151
152        challenger.observe_variable_length_extension_slice(host_numerator.guts().as_slice());
153        challenger.observe_variable_length_extension_slice(host_denominator.guts().as_slice());
154        let output_host =
155            LogUpGkrOutput { numerator: host_numerator, denominator: host_denominator };
156
157        // TODO: instead calculate from number of interactions.
158        let initial_number_of_variables = numerator.num_variables();
159        assert_eq!(initial_number_of_variables, num_interaction_variables + 1);
160        let first_eval_point = challenger.sample_point::<GC::EF>(initial_number_of_variables);
161
162        // Follow the GKR protocol layer by layer.
163        let first_point = numerator.backend().copy_to(&first_eval_point).unwrap();
164        let first_point_eq = Mle::partial_lagrange(&first_point);
165        let first_numerator_eval = numerator.eval_at_eq(&first_point_eq).to_host().unwrap()[0];
166        let first_denominator_eval = denominator.eval_at_eq(&first_point_eq).to_host().unwrap()[0];
167
168        let (eval_point, round_proofs) = {
169            let _span = tracing::debug_span!("prove GKR circuit").entered();
170            self.prove_gkr_circuit(
171                first_numerator_eval,
172                first_denominator_eval,
173                first_eval_point,
174                circuit,
175                challenger,
176            )
177        };
178
179        // Get the evaluations for each chip at the evaluation point of the last round.
180        let mut chip_evaluations = BTreeMap::new();
181
182        let trace_dimension = traces.values().next().unwrap().num_variables();
183        let eval_point = eval_point.last_k(trace_dimension as usize);
184        let eval_point_b = numerator.backend().copy_to(&eval_point).unwrap();
185        let eval_point_eq = Mle::partial_lagrange(&eval_point_b);
186
187        challenger.observe(GC::F::from_canonical_usize(chips.len()));
188        for chip in chips.iter() {
189            let name = chip.name();
190            let main_trace = traces.get(name).unwrap();
191            let preprocessed_trace = preprocessed_traces.get(name);
192
193            let main_evaluation = main_trace.eval_at_eq(&eval_point, &eval_point_eq);
194            let preprocessed_evaluation =
195                preprocessed_trace.as_ref().map(|t| t.eval_at_eq(&eval_point, &eval_point_eq));
196            let main_evaluation = main_evaluation.to_host().unwrap();
197            let preprocessed_evaluation = preprocessed_evaluation.map(|e| e.to_host().unwrap());
198            let openings = ChipEvaluation {
199                main_trace_evaluations: main_evaluation,
200                preprocessed_trace_evaluations: preprocessed_evaluation,
201            };
202            // Observe the openings.
203            if let Some(prep_eval) = openings.preprocessed_trace_evaluations.as_ref() {
204                challenger.observe_variable_length_extension_slice(prep_eval);
205            }
206            challenger.observe_variable_length_extension_slice(&openings.main_trace_evaluations);
207
208            chip_evaluations.insert(name.to_string(), openings);
209        }
210
211        let logup_evaluations =
212            LogUpEvaluations { point: eval_point, chip_openings: chip_evaluations };
213
214        LogupGkrProof { circuit_output: output_host, round_proofs, logup_evaluations, witness }
215    }
216}