Skip to main content

sp1_hypercube/logup_gkr/
prover.rs

1use std::collections::{BTreeMap, BTreeSet};
2
3use slop_algebra::AbstractField;
4use slop_alloc::{CanCopyFromRef, CpuBackend, ToHost};
5use slop_challenger::{CanObserve, FieldChallenger, IopCtx, VariableLengthChallenger};
6use slop_multilinear::{Mle, MultilinearPcsChallenger, Point};
7
8use crate::{
9    air::MachineAir, prove_gkr_round, prover::Traces, Chip, ChipEvaluation, LogupGkrCpuCircuit,
10    LogupGkrCpuTraceGenerator, ShardContext,
11};
12
13use super::{LogUpEvaluations, LogUpGkrOutput, LogupGkrProof, LogupGkrRoundProof};
14
15/// TODO
16pub struct GkrProverImpl<GC: IopCtx, SC: ShardContext<GC>> {
17    /// TODO
18    trace_generator: LogupGkrCpuTraceGenerator<GC::F, GC::EF, SC::Air>,
19}
20
21/// TODO
22impl<GC: IopCtx, SC: ShardContext<GC>> GkrProverImpl<GC, SC> {
23    /// TODO
24    #[must_use]
25    pub fn new(trace_generator: LogupGkrCpuTraceGenerator<GC::F, GC::EF, SC::Air>) -> Self {
26        Self { trace_generator }
27    }
28
29    /// TODO
30    pub fn prove_gkr_circuit(
31        &self,
32        numerator_value: GC::EF,
33        denominator_value: GC::EF,
34        eval_point: Point<GC::EF>,
35        mut circuit: LogupGkrCpuCircuit<GC::F, GC::EF>,
36        challenger: &mut GC::Challenger,
37    ) -> (Point<GC::EF>, Vec<LogupGkrRoundProof<GC::EF>>) {
38        let mut round_proofs = Vec::new();
39        // Follow the GKR protocol layer by layer.
40        let mut numerator_eval = numerator_value;
41        let mut denominator_eval = denominator_value;
42        let mut eval_point = eval_point;
43        while let Some(layer) = circuit.next_layer() {
44            let round_proof =
45                prove_gkr_round(layer, &eval_point, numerator_eval, denominator_eval, challenger);
46            // Observe the prover message.
47            challenger.observe_ext_element(round_proof.numerator_0);
48            challenger.observe_ext_element(round_proof.numerator_1);
49            challenger.observe_ext_element(round_proof.denominator_0);
50            challenger.observe_ext_element(round_proof.denominator_1);
51            // Get the evaluation point for the claims of the next round.
52            eval_point = round_proof.sumcheck_proof.point_and_eval.0.clone();
53            // Sample the last coordinate.
54            let last_coordinate = challenger.sample_ext_element::<GC::EF>();
55            // Compute the evaluation of the numerator and denominator at the last coordinate.
56            numerator_eval = round_proof.numerator_0
57                + (round_proof.numerator_1 - round_proof.numerator_0) * last_coordinate;
58            denominator_eval = round_proof.denominator_0
59                + (round_proof.denominator_1 - round_proof.denominator_0) * last_coordinate;
60            eval_point.add_dimension_back(last_coordinate);
61            // Add the round proof to the total
62            round_proofs.push(round_proof);
63        }
64        (eval_point, round_proofs)
65    }
66
67    #[allow(clippy::too_many_arguments)]
68    pub(crate) fn prove_logup_gkr(
69        &self,
70        chips: &BTreeSet<Chip<GC::F, SC::Air>>,
71        preprocessed_traces: &Traces<GC::F, CpuBackend>,
72        traces: &Traces<GC::F, CpuBackend>,
73        public_values: Vec<GC::F>,
74        alpha: GC::EF,
75        beta_seed: Point<GC::EF>,
76        challenger: &mut GC::Challenger,
77    ) -> LogupGkrProof<GC::EF> {
78        let num_interactions =
79            chips.iter().map(|chip| chip.sends().len() + chip.receives().len()).sum::<usize>();
80        let num_interaction_variables = num_interactions.next_power_of_two().ilog2();
81
82        #[cfg(sp1_debug_constraints)]
83        {
84            use crate::{
85                air::InteractionScope, debug_interactions_with_all_chips, InteractionKind,
86            };
87            use slop_alloc::CanCopyIntoRef;
88
89            let mut host_preprocessed_traces = BTreeMap::new();
90
91            for (name, preprocessed_trace) in preprocessed_traces.iter() {
92                let host_preprocessed_trace =
93                    CpuBackend::copy_to_dst(&CpuBackend, preprocessed_trace).unwrap();
94                host_preprocessed_traces.insert(name.clone(), host_preprocessed_trace);
95            }
96
97            let mut host_traces = BTreeMap::new();
98            for (name, trace) in traces.iter() {
99                let host_trace = CpuBackend::copy_to_dst(&CpuBackend, trace).unwrap();
100                host_traces.insert(name.clone(), host_trace);
101            }
102
103            let host_traces = Traces { named_traces: host_traces };
104
105            let host_preprocessed_traces = Traces { named_traces: host_preprocessed_traces };
106
107            debug_interactions_with_all_chips::<GC::F, SC::Air>(
108                &chips.iter().cloned().collect::<Vec<_>>(),
109                &host_preprocessed_traces,
110                &host_traces,
111                public_values.clone(),
112                InteractionKind::all_kinds(),
113                InteractionScope::Local,
114            );
115        }
116
117        // Run the GKR circuit and get the output.
118        let (output, circuit) = {
119            let _span = tracing::info_span!("generate GKR circuit").entered();
120            self.trace_generator.generate_gkr_circuit(
121                chips,
122                preprocessed_traces.clone(),
123                traces.clone(),
124                public_values,
125                alpha,
126                beta_seed,
127            )
128        };
129
130        let LogUpGkrOutput { numerator, denominator } = &output;
131
132        let host_numerator = numerator.to_host().unwrap();
133        let host_denominator = denominator.to_host().unwrap();
134
135        challenger.observe_variable_length_extension_slice(host_numerator.guts().as_slice());
136        challenger.observe_variable_length_extension_slice(host_denominator.guts().as_slice());
137        let output_host =
138            LogUpGkrOutput { numerator: host_numerator, denominator: host_denominator };
139
140        // TODO: instead calculate from number of interactions.
141        let initial_number_of_variables = numerator.num_variables();
142        assert_eq!(initial_number_of_variables, num_interaction_variables + 1);
143        let first_eval_point = challenger.sample_point::<GC::EF>(initial_number_of_variables);
144
145        // Follow the GKR protocol layer by layer.
146        let first_point = numerator.backend().copy_to(&first_eval_point).unwrap();
147        let first_point_eq = Mle::partial_lagrange(&first_point);
148        let first_numerator_eval = numerator.eval_at_eq(&first_point_eq).to_host().unwrap()[0];
149        let first_denominator_eval = denominator.eval_at_eq(&first_point_eq).to_host().unwrap()[0];
150
151        let (eval_point, round_proofs) = {
152            let _span = tracing::info_span!("prove GKR circuit").entered();
153            self.prove_gkr_circuit(
154                first_numerator_eval,
155                first_denominator_eval,
156                first_eval_point,
157                circuit,
158                challenger,
159            )
160        };
161
162        // Get the evaluations for each chip at the evaluation point of the last round.
163        let mut chip_evaluations = BTreeMap::new();
164
165        let trace_dimension = traces.values().next().unwrap().num_variables();
166        let eval_point = eval_point.last_k(trace_dimension as usize);
167        let eval_point_b = numerator.backend().copy_to(&eval_point).unwrap();
168        let eval_point_eq = Mle::partial_lagrange(&eval_point_b);
169
170        challenger.observe(GC::F::from_canonical_usize(chips.len()));
171        for chip in chips.iter() {
172            let name = chip.name();
173            let main_trace = traces.get(name).unwrap();
174            let preprocessed_trace = preprocessed_traces.get(name);
175
176            let main_evaluation = main_trace.eval_at_eq(&eval_point, &eval_point_eq);
177            let preprocessed_evaluation =
178                preprocessed_trace.as_ref().map(|t| t.eval_at_eq(&eval_point, &eval_point_eq));
179            let main_evaluation = main_evaluation.to_host().unwrap();
180            let preprocessed_evaluation = preprocessed_evaluation.map(|e| e.to_host().unwrap());
181            let openings = ChipEvaluation {
182                main_trace_evaluations: main_evaluation,
183                preprocessed_trace_evaluations: preprocessed_evaluation,
184            };
185            // Observe the openings.
186            if let Some(prep_eval) = openings.preprocessed_trace_evaluations.as_ref() {
187                challenger.observe_variable_length_extension_slice(prep_eval);
188            }
189            challenger.observe_variable_length_extension_slice(&openings.main_trace_evaluations);
190
191            chip_evaluations.insert(name.to_string(), openings);
192        }
193
194        let logup_evaluations =
195            LogUpEvaluations { point: eval_point, chip_openings: chip_evaluations };
196
197        LogupGkrProof { circuit_output: output_host, round_proofs, logup_evaluations }
198    }
199}