Skip to main content

sp1_hypercube/logup_gkr/cpu.rs

1use std::{
2    collections::{BTreeMap, BTreeSet},
3    marker::PhantomData,
4    sync::Arc,
5};
6
7use slop_algebra::{ExtensionField, Field};
8use slop_alloc::CpuBackend;
9use slop_challenger::FieldChallenger;
10use slop_multilinear::{Mle, PaddedMle, Point};
11use slop_sumcheck::reduce_sumcheck_to_evaluation;
12
13use crate::{air::MachineAir, prover::Traces, Chip, LogupRoundPolynomial, PolynomialLayer};
14
15use super::LogUpGkrOutput;
16
/// CPU generator for the LogUp-GKR circuit.
///
/// The generator holds no data: its type parameters only fix the base field `F`, the
/// extension field `EF`, and the AIR type `A` that its methods operate on.
pub struct LogupGkrCpuTraceGenerator<F, EF, A>(PhantomData<(F, EF, A)>);

// Implemented by hand instead of `#[derive(Default)]` so that `F`, `EF` and `A` are not
// themselves required to implement `Default`.
impl<F, EF, A> Default for LogupGkrCpuTraceGenerator<F, EF, A> {
    fn default() -> Self {
        LogupGkrCpuTraceGenerator(PhantomData)
    }
}
25
26/// A trace generator for the GKR circuit.
27pub struct LogupGkrCpuCircuit<F: Field, EF> {
28    layers: Vec<GkrCircuitLayer<F, EF>>,
29}
30
31/// A layer of the GKR circuit.
32pub enum GkrCircuitLayer<F: Field, EF> {
33    /// An intermediate layer of the GKR circuit.
34    Layer(LogUpGkrCpuLayer<EF, EF>),
35    /// The first layer of the GKR circuit.
36    FirstLayer(LogUpGkrCpuLayer<F, EF>),
37}
38
39/// A layer of the GKR circuit.
40pub struct LogUpGkrCpuLayer<F, EF> {
41    /// The numerators of the layer (`PaddedMle<F>` per table with dimensions `num_row_variables` x
42    /// `num_interaction_variables`)
43    pub numerator_0: Vec<PaddedMle<F>>,
44    /// The denominators of the layer (`PaddedMle<EF>` per table with dimensions
45    /// `num_row_variables` x `num_interaction_variables`)
46    pub denominator_0: Vec<PaddedMle<EF>>,
47    /// The numerators of the layer (`PaddedMle<F>` per table with dimensions `num_row_variables` x
48    /// `num_interaction_variables`)
49    pub numerator_1: Vec<PaddedMle<F>>,
50    /// The denominators of the layer (`PaddedMle<EF>` per table with dimensions
51    /// `num_row_variables` x `num_interaction_variables`)
52    pub denominator_1: Vec<PaddedMle<EF>>,
53    /// The number of row variables (log height of each mle)
54    pub num_row_variables: usize,
55    /// The number of interaction variables (log width of each mle)
56    pub num_interaction_variables: usize,
57}
58
59/// An interaction layer of the GKR circuit (`num_row_variables` == 1).
60pub struct InteractionLayer<F, EF> {
61    /// The numerators of the layer (`PaddedMle<F>` per table with dimensions
62    /// `num_interaction_variables` x 1)
63    pub numerator_0: Arc<Mle<F>>,
64    /// The denominators of the layer (`PaddedMle<EF>` per table with dimensions
65    /// `num_interaction_variables` x 1)
66    pub denominator_0: Arc<Mle<EF>>,
67    /// The numerators of the layer (`PaddedMle<F>` per table with dimensions
68    /// `num_interaction_variables` x 1)
69    pub numerator_1: Arc<Mle<F>>,
70    /// The denominators of the layer (`PaddedMle<EF>` per table with dimensions
71    /// `num_interaction_variables` x 1)
72    pub denominator_1: Arc<Mle<EF>>,
73}
74
75impl<F: Field, EF: ExtensionField<F>, A: MachineAir<F>> LogupGkrCpuTraceGenerator<F, EF, A> {
76    #[allow(unused_variables)]
77    #[allow(clippy::needless_pass_by_value)]
78    pub(crate) fn generate_gkr_circuit(
79        &self,
80        chips: &BTreeSet<Chip<F, A>>,
81        preprocessed_traces: Traces<F, CpuBackend>,
82        traces: Traces<F, CpuBackend>,
83        public_values: Vec<F>,
84        alpha: EF,
85        beta_seed: Point<EF>,
86    ) -> (LogUpGkrOutput<EF>, LogupGkrCpuCircuit<F, EF>) {
87        let interactions = chips
88            .iter()
89            .map(|chip| {
90                let interactions = chip
91                    .sends()
92                    .iter()
93                    .map(|int| (int, true))
94                    .chain(chip.receives().iter().map(|int| (int, false)))
95                    .collect::<Vec<_>>();
96                (chip.name().to_string(), interactions)
97            })
98            .collect::<BTreeMap<_, _>>();
99
100        let first_layer = self.generate_first_layer(
101            &interactions,
102            &traces,
103            &preprocessed_traces,
104            alpha,
105            beta_seed,
106        );
107        let num_row_variables = first_layer.num_row_variables;
108        // println!("num_row_variables: {:?}", num_row_variables);
109        let num_interaction_variables = first_layer.num_interaction_variables;
110        let mut layers = Vec::new();
111        layers.push(GkrCircuitLayer::FirstLayer(first_layer));
112
113        for _ in 0..num_row_variables - 1 {
114            let next_layer = match layers.last().unwrap() {
115                GkrCircuitLayer::Layer(layer) => self.layer_transition(layer),
116                GkrCircuitLayer::FirstLayer(layer) => self.layer_transition(layer),
117            };
118            layers.push(GkrCircuitLayer::Layer(next_layer));
119        }
120
121        let last_layer = layers.last().unwrap();
122        let last_layer = match last_layer {
123            GkrCircuitLayer::Layer(layer) => layer,
124            GkrCircuitLayer::FirstLayer(layer) => unreachable!(),
125        };
126        assert_eq!(last_layer.num_row_variables, 1);
127        let output = self.extract_outputs(last_layer);
128
129        let circuit_generator = Some(Self::default());
130        let circuit = LogupGkrCpuCircuit { layers };
131
132        (output, circuit)
133    }
134}
135
136impl<F: Field, EF: ExtensionField<F>> Iterator for LogupGkrCpuCircuit<F, EF> {
137    type Item = GkrCircuitLayer<F, EF>;
138
139    fn next(&mut self) -> Option<Self::Item> {
140        self.layers.pop()
141    }
142}
143
144/// Basic information about the GKR circuit.
145impl<F: Field, EF: ExtensionField<F>> LogupGkrCpuCircuit<F, EF> {
146    pub(crate) fn next_layer(&mut self) -> Option<GkrCircuitLayer<F, EF>> {
147        self.layers.pop()
148    }
149}
150
151pub(crate) fn prove_gkr_round<F: Field, EF: ExtensionField<F>, Challenger: FieldChallenger<F>>(
152    circuit: GkrCircuitLayer<F, EF>,
153    eval_point: &slop_multilinear::Point<EF>,
154    numerator_eval: EF,
155    denominator_eval: EF,
156    challenger: &mut Challenger,
157) -> super::LogupGkrRoundProof<EF> {
158    let lambda = challenger.sample_ext_element::<EF>();
159
160    let (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof) = match circuit {
161        GkrCircuitLayer::Layer(layer) => {
162            let (interaction_point, row_point) =
163                eval_point.split_at(layer.num_interaction_variables);
164            let eq_interaction = Mle::partial_lagrange(&interaction_point);
165            let eq_row = Mle::partial_lagrange(&row_point);
166            let sumcheck_poly = LogupRoundPolynomial {
167                layer: PolynomialLayer::CircuitLayer(layer),
168                eq_row: Arc::new(eq_row),
169                eq_interaction: Arc::new(eq_interaction),
170                lambda,
171                eq_adjustment: EF::one(),
172                padding_adjustment: EF::one(),
173                point: eval_point.clone(),
174            };
175            let claim = numerator_eval * lambda + denominator_eval;
176
177            let (sumcheck_proof, mut openings) = reduce_sumcheck_to_evaluation(
178                vec![sumcheck_poly],
179                challenger,
180                vec![claim],
181                1,
182                lambda,
183            );
184
185            let openings = openings.pop().unwrap();
186            let [numerator_0, denominator_0, numerator_1, denominator_1] =
187                openings.try_into().unwrap();
188            (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof)
189        }
190        GkrCircuitLayer::FirstLayer(layer) => {
191            let (interaction_point, row_point) =
192                eval_point.split_at(layer.num_interaction_variables);
193            let eq_interaction = Mle::partial_lagrange(&interaction_point);
194            let eq_row = Mle::partial_lagrange(&row_point);
195            let sumcheck_poly = LogupRoundPolynomial {
196                layer: PolynomialLayer::CircuitLayer(layer),
197                eq_row: Arc::new(eq_row),
198                eq_interaction: Arc::new(eq_interaction),
199                lambda,
200                eq_adjustment: EF::one(),
201                padding_adjustment: EF::one(),
202                point: eval_point.clone(),
203            };
204            let claim = numerator_eval * lambda + denominator_eval;
205            let (sumcheck_proof, mut openings) = reduce_sumcheck_to_evaluation(
206                vec![sumcheck_poly],
207                challenger,
208                vec![claim],
209                1,
210                lambda,
211            );
212            let openings = openings.pop().unwrap();
213            let [numerator_0, denominator_0, numerator_1, denominator_1] =
214                openings.try_into().unwrap();
215            (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof)
216        }
217    };
218
219    super::LogupGkrRoundProof {
220        numerator_0,
221        numerator_1,
222        denominator_0,
223        denominator_1,
224        sumcheck_proof,
225    }
226}