1use std::{
2 collections::{BTreeMap, BTreeSet},
3 marker::PhantomData,
4 sync::Arc,
5};
6
7use slop_algebra::{ExtensionField, Field};
8use slop_alloc::CpuBackend;
9use slop_challenger::FieldChallenger;
10use slop_multilinear::{Mle, PaddedMle, Point};
11use slop_sumcheck::reduce_sumcheck_to_evaluation;
12
13use crate::{air::MachineAir, prover::Traces, Chip, LogupRoundPolynomial, PolynomialLayer};
14
15use super::LogUpGkrOutput;
16
/// CPU-side generator for the LogUp GKR circuit.
///
/// The generator holds no data. The `PhantomData` only ties it to the field `F`, the
/// extension field `EF` (its `impl` requires `EF: ExtensionField<F>`) and the AIR type `A`
/// (required to be `MachineAir<F>`) that its methods work with.
pub struct LogupGkrCpuTraceGenerator<F, EF, A>(PhantomData<(F, EF, A)>);
19
20impl<F, EF, A> Default for LogupGkrCpuTraceGenerator<F, EF, A> {
21 fn default() -> Self {
22 Self(PhantomData)
23 }
24}
25
26pub struct LogupGkrCpuCircuit<F: Field, EF> {
28 layers: Vec<GkrCircuitLayer<F, EF>>,
29}
30
31pub enum GkrCircuitLayer<F: Field, EF> {
33 Layer(LogUpGkrCpuLayer<EF, EF>),
35 FirstLayer(LogUpGkrCpuLayer<F, EF>),
37}
38
/// Numerator and denominator multilinears making up one LogUp GKR layer.
///
/// Numerators are over `F` and denominators over `EF`. The `_0` / `_1` fields come in
/// matching pairs. NOTE(review): they appear to be the two halves that are combined when
/// producing the next layer; confirm which variable they are split along against
/// `layer_transition`.
pub struct LogUpGkrCpuLayer<F, EF> {
    /// Numerator MLEs of the `0` half (presumably one entry per chip — TODO confirm).
    pub numerator_0: Vec<PaddedMle<F>>,
    /// Denominator MLEs of the `0` half, paired with `numerator_0`.
    pub denominator_0: Vec<PaddedMle<EF>>,
    /// Numerator MLEs of the `1` half.
    pub numerator_1: Vec<PaddedMle<F>>,
    /// Denominator MLEs of the `1` half, paired with `numerator_1`.
    pub denominator_1: Vec<PaddedMle<EF>>,
    /// Number of row variables in this layer. The trailing coordinates of an evaluation
    /// point (after the interaction coordinates) address rows.
    pub num_row_variables: usize,
    /// Number of interaction variables in this layer. The leading coordinates of an
    /// evaluation point address interactions.
    pub num_interaction_variables: usize,
}
58
/// Shared (`Arc`-wrapped) numerator and denominator MLEs grouped as two `_0` / `_1` pairs.
///
/// NOTE(review): this type is not used in the visible code; its intended role (e.g. the
/// per-interaction counterpart of `LogUpGkrCpuLayer`) should be confirmed against its users.
pub struct InteractionLayer<F, EF> {
    /// Numerator MLE of the `0` half, over `F`.
    pub numerator_0: Arc<Mle<F>>,
    /// Denominator MLE of the `0` half, over `EF`.
    pub denominator_0: Arc<Mle<EF>>,
    /// Numerator MLE of the `1` half, over `F`.
    pub numerator_1: Arc<Mle<F>>,
    /// Denominator MLE of the `1` half, over `EF`.
    pub denominator_1: Arc<Mle<EF>>,
}
74
75impl<F: Field, EF: ExtensionField<F>, A: MachineAir<F>> LogupGkrCpuTraceGenerator<F, EF, A> {
76 #[allow(unused_variables)]
77 #[allow(clippy::needless_pass_by_value)]
78 pub(crate) fn generate_gkr_circuit(
79 &self,
80 chips: &BTreeSet<Chip<F, A>>,
81 preprocessed_traces: Traces<F, CpuBackend>,
82 traces: Traces<F, CpuBackend>,
83 public_values: Vec<F>,
84 alpha: EF,
85 beta_seed: Point<EF>,
86 ) -> (LogUpGkrOutput<EF>, LogupGkrCpuCircuit<F, EF>) {
87 let interactions = chips
88 .iter()
89 .map(|chip| {
90 let interactions = chip
91 .sends()
92 .iter()
93 .map(|int| (int, true))
94 .chain(chip.receives().iter().map(|int| (int, false)))
95 .collect::<Vec<_>>();
96 (chip.name().to_string(), interactions)
97 })
98 .collect::<BTreeMap<_, _>>();
99
100 let first_layer = self.generate_first_layer(
101 &interactions,
102 &traces,
103 &preprocessed_traces,
104 alpha,
105 beta_seed,
106 );
107 let num_row_variables = first_layer.num_row_variables;
108 let num_interaction_variables = first_layer.num_interaction_variables;
110 let mut layers = Vec::new();
111 layers.push(GkrCircuitLayer::FirstLayer(first_layer));
112
113 for _ in 0..num_row_variables - 1 {
114 let next_layer = match layers.last().unwrap() {
115 GkrCircuitLayer::Layer(layer) => self.layer_transition(layer),
116 GkrCircuitLayer::FirstLayer(layer) => self.layer_transition(layer),
117 };
118 layers.push(GkrCircuitLayer::Layer(next_layer));
119 }
120
121 let last_layer = layers.last().unwrap();
122 let last_layer = match last_layer {
123 GkrCircuitLayer::Layer(layer) => layer,
124 GkrCircuitLayer::FirstLayer(layer) => unreachable!(),
125 };
126 assert_eq!(last_layer.num_row_variables, 1);
127 let output = self.extract_outputs(last_layer);
128
129 let circuit_generator = Some(Self::default());
130 let circuit = LogupGkrCpuCircuit { layers };
131
132 (output, circuit)
133 }
134}
135
136impl<F: Field, EF: ExtensionField<F>> Iterator for LogupGkrCpuCircuit<F, EF> {
137 type Item = GkrCircuitLayer<F, EF>;
138
139 fn next(&mut self) -> Option<Self::Item> {
140 self.layers.pop()
141 }
142}
143
144impl<F: Field, EF: ExtensionField<F>> LogupGkrCpuCircuit<F, EF> {
146 pub(crate) fn next_layer(&mut self) -> Option<GkrCircuitLayer<F, EF>> {
147 self.layers.pop()
148 }
149}
150
151pub(crate) fn prove_gkr_round<F: Field, EF: ExtensionField<F>, Challenger: FieldChallenger<F>>(
152 circuit: GkrCircuitLayer<F, EF>,
153 eval_point: &slop_multilinear::Point<EF>,
154 numerator_eval: EF,
155 denominator_eval: EF,
156 challenger: &mut Challenger,
157) -> super::LogupGkrRoundProof<EF> {
158 let lambda = challenger.sample_ext_element::<EF>();
159
160 let (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof) = match circuit {
161 GkrCircuitLayer::Layer(layer) => {
162 let (interaction_point, row_point) =
163 eval_point.split_at(layer.num_interaction_variables);
164 let eq_interaction = Mle::partial_lagrange(&interaction_point);
165 let eq_row = Mle::partial_lagrange(&row_point);
166 let sumcheck_poly = LogupRoundPolynomial {
167 layer: PolynomialLayer::CircuitLayer(layer),
168 eq_row: Arc::new(eq_row),
169 eq_interaction: Arc::new(eq_interaction),
170 lambda,
171 eq_adjustment: EF::one(),
172 padding_adjustment: EF::one(),
173 point: eval_point.clone(),
174 };
175 let claim = numerator_eval * lambda + denominator_eval;
176
177 let (sumcheck_proof, mut openings) = reduce_sumcheck_to_evaluation(
178 vec![sumcheck_poly],
179 challenger,
180 vec![claim],
181 1,
182 lambda,
183 );
184
185 let openings = openings.pop().unwrap();
186 let [numerator_0, denominator_0, numerator_1, denominator_1] =
187 openings.try_into().unwrap();
188 (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof)
189 }
190 GkrCircuitLayer::FirstLayer(layer) => {
191 let (interaction_point, row_point) =
192 eval_point.split_at(layer.num_interaction_variables);
193 let eq_interaction = Mle::partial_lagrange(&interaction_point);
194 let eq_row = Mle::partial_lagrange(&row_point);
195 let sumcheck_poly = LogupRoundPolynomial {
196 layer: PolynomialLayer::CircuitLayer(layer),
197 eq_row: Arc::new(eq_row),
198 eq_interaction: Arc::new(eq_interaction),
199 lambda,
200 eq_adjustment: EF::one(),
201 padding_adjustment: EF::one(),
202 point: eval_point.clone(),
203 };
204 let claim = numerator_eval * lambda + denominator_eval;
205 let (sumcheck_proof, mut openings) = reduce_sumcheck_to_evaluation(
206 vec![sumcheck_poly],
207 challenger,
208 vec![claim],
209 1,
210 lambda,
211 );
212 let openings = openings.pop().unwrap();
213 let [numerator_0, denominator_0, numerator_1, denominator_1] =
214 openings.try_into().unwrap();
215 (numerator_0, denominator_0, numerator_1, denominator_1, sumcheck_proof)
216 }
217 };
218
219 super::LogupGkrRoundProof {
220 numerator_0,
221 numerator_1,
222 denominator_0,
223 denominator_1,
224 sumcheck_proof,
225 }
226}