1use serde::{Deserialize, Serialize};
2use slop_algebra::{Field, TwoAdicField};
3use slop_alloc::{CpuBackend, ToHost};
4use slop_basefold_prover::{BasefoldProver, BasefoldProverData, BasefoldProverError};
5use slop_challenger::IopCtx;
6use slop_commit::{Message, Rounds};
7use slop_merkle_tree::ComputeTcsOpenings;
8use slop_multilinear::{Evaluations, Mle, MleEval, MultilinearPcsProver, Point, ToMle};
9use std::fmt::Debug;
10
11use crate::{interleave_multilinears_with_fixed_rate, StackedBasefoldProof};
12
/// A prover for a "stacked" multilinear PCS built on top of BaseFold.
///
/// Incoming multilinears are interleaved into fixed-height columns of
/// `1 << log_stacking_height` entries before being committed with the inner
/// [`BasefoldProver`].
#[derive(Clone)]
pub struct StackedPcsProver<P: ComputeTcsOpenings<GC, CpuBackend>, GC: IopCtx<F: TwoAdicField>> {
    // The inner BaseFold prover used to commit to the interleaved columns.
    basefold_prover: BasefoldProver<GC, P>,
    /// Base-2 logarithm of the height each stacked column is padded to.
    pub log_stacking_height: u32,
    /// Batch size forwarded to `interleave_multilinears_with_fixed_rate` when
    /// interleaving the committed multilinears.
    pub batch_size: usize,
    // NOTE(review): `GC` already appears in `basefold_prover`'s type, so this
    // marker looks redundant — confirm before removing (removal would also
    // require updating `new`).
    _marker: std::marker::PhantomData<GC>,
}
20
/// Prover-side data produced by committing a round of multilinears.
///
/// Bundles the inner BaseFold commitment data with the interleaved MLEs that
/// were actually committed, so evaluations can be recomputed later.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StackedBasefoldProverData<M, F, TcsProverData> {
    // Commitment data of the inner BaseFold PCS.
    pcs_batch_data: BasefoldProverData<F, TcsProverData>,
    /// The interleaved (stacked) multilinears that were committed.
    pub interleaved_mles: Message<M>,
}
26
impl<F: Field, PD> ToMle<F> for StackedBasefoldProverData<Mle<F>, F, PD> {
    /// Returns a clone of the committed interleaved MLEs.
    fn interleaved_mles(&self) -> Message<Mle<F, CpuBackend>> {
        self.interleaved_mles.clone()
    }
}
32
33impl<GC, P> StackedPcsProver<P, GC>
34where
35 GC: IopCtx<F: TwoAdicField, EF: TwoAdicField>,
36 P: ComputeTcsOpenings<GC, CpuBackend>,
37{
38 pub const fn new(
39 basefold_prover: BasefoldProver<GC, P>,
40 log_stacking_height: u32,
41 batch_size: usize,
42 ) -> Self {
43 Self { basefold_prover, log_stacking_height, batch_size, _marker: std::marker::PhantomData }
44 }
45
46 pub fn round_batch_evaluations(
47 &self,
48 stacked_point: &Point<GC::EF>,
49 prover_data: &StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>,
50 ) -> Evaluations<GC::EF> {
51 prover_data
52 .interleaved_mles
53 .iter()
54 .map(|mle| mle.eval_at(stacked_point))
55 .collect::<Evaluations<_, _>>()
56 }
57
58 #[allow(clippy::type_complexity)]
59 pub fn commit_multilinears(
60 &self,
61 multilinears: Message<Mle<GC::F>>,
62 ) -> Result<
63 (GC::Digest, StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>, usize),
64 BasefoldProverError<P::ProverError>,
65 > {
66 let next_multiple = multilinears
71 .iter()
72 .map(|mle| mle.num_non_zero_entries() * mle.num_polynomials())
73 .sum::<usize>()
74 .next_multiple_of(1 << self.log_stacking_height)
75 .max(1 << self.log_stacking_height);
77
78 let num_added_vals = next_multiple
79 - multilinears
80 .iter()
81 .map(|mle| mle.num_non_zero_entries() * mle.num_polynomials())
82 .sum::<usize>();
83
84 let interleaved_mles = interleave_multilinears_with_fixed_rate(
85 self.batch_size,
86 multilinears,
87 self.log_stacking_height,
88 );
89 let (commit, pcs_batch_data) =
90 self.basefold_prover.commit_mles(interleaved_mles.clone())?;
91 let prover_data = StackedBasefoldProverData { pcs_batch_data, interleaved_mles };
92
93 Ok((commit, prover_data, num_added_vals))
94 }
95}
96
97impl<GC: IopCtx<F: TwoAdicField, EF: TwoAdicField>, P: ComputeTcsOpenings<GC, CpuBackend>>
98 MultilinearPcsProver<GC, StackedBasefoldProof<GC>> for StackedPcsProver<P, GC>
99{
100 type ProverData = StackedBasefoldProverData<Mle<GC::F>, GC::F, P::ProverData>;
101
102 type ProverError = BasefoldProverError<P::ProverError>;
103
104 fn commit_multilinear(
105 &self,
106 mles: Message<Mle<<GC as IopCtx>::F>>,
107 ) -> Result<(<GC as IopCtx>::Digest, Self::ProverData, usize), Self::ProverError> {
108 self.commit_multilinears(mles)
109 }
110
111 fn prove_trusted_evaluation(
112 &self,
113 eval_point: Point<<GC as IopCtx>::EF>,
114 _evaluation_claim: <GC as IopCtx>::EF,
115 prover_data: Rounds<Self::ProverData>,
116 challenger: &mut <GC as IopCtx>::Challenger,
117 ) -> Result<StackedBasefoldProof<GC>, Self::ProverError> {
118 let (_, stack_point) =
119 eval_point.split_at(eval_point.dimension() - self.log_stacking_height as usize);
120 let batch_evaluations: Rounds<_> = prover_data
121 .iter()
122 .map(|data| self.round_batch_evaluations(&stack_point, data))
123 .collect();
124
125 let mut host_batch_evaluations = Rounds::new();
126 for round_evals in batch_evaluations.iter() {
127 let mut host_round_evals = vec![];
128 for eval in round_evals.iter() {
129 let host_eval = eval.to_host().unwrap();
130 host_round_evals.extend(host_eval);
131 }
132 let host_round_evals = Evaluations::new(vec![host_round_evals.into()]);
133 host_batch_evaluations.push(host_round_evals);
134 }
135 let (pcs_prover_data, mle_rounds): (Rounds<_>, Rounds<_>) = prover_data
136 .into_iter()
137 .map(|data| (data.pcs_batch_data, data.interleaved_mles))
138 .unzip();
139
140 let (_, stack_point) =
141 eval_point.split_at(eval_point.dimension() - self.log_stacking_height as usize);
142
143 let pcs_proof = self.basefold_prover.prove_untrusted_evaluations(
144 stack_point,
145 mle_rounds,
146 batch_evaluations,
147 pcs_prover_data,
148 challenger,
149 )?;
150
151 let host_batch_evaluations = host_batch_evaluations
152 .into_iter()
153 .map(|round| round.into_iter().flatten().collect::<MleEval<_>>())
154 .collect::<Rounds<_>>();
155
156 Ok(StackedBasefoldProof {
157 basefold_proof: pcs_proof,
158 batch_evaluations: host_batch_evaluations,
159 })
160 }
161
162 fn log_max_padding_amount(&self) -> u32 {
163 self.log_stacking_height
164 }
165}
#[cfg(test)]
mod tests {
    use rand::thread_rng;
    use slop_algebra::extension::BinomialExtensionField;
    use slop_baby_bear::{baby_bear_poseidon2::BabyBearDegree4Duplex, BabyBear};
    use slop_basefold::{BasefoldVerifier, FriConfig};
    use slop_basefold_prover::BasefoldProver;
    use slop_challenger::CanObserve;
    use slop_merkle_tree::Poseidon2BabyBear16Prover;
    use slop_tensor::Tensor;

    use crate::StackedPcsVerifier;

    use super::*;

    /// End-to-end round trip: commit multilinears of mixed widths/heights,
    /// prove an evaluation, and verify it with `StackedPcsVerifier`.
    #[test]
    fn test_stacked_prover_with_fixed_rate_interleave() {
        let log_stacking_height = 10;
        let batch_size = 10;

        type GC = BabyBearDegree4Duplex;
        type Prover = BasefoldProver<GC, Poseidon2BabyBear16Prover>;
        type EF = BinomialExtensionField<BabyBear, 4>;

        // One round of MLEs given as (width, log_height) pairs; the widths are
        // chosen so the total area is exactly a power of two (asserted below).
        let round_widths_and_log_heights = [vec![(1 << 10, 10), (1 << 4, 11), (496, 11)]];

        let total_data_length = round_widths_and_log_heights
            .iter()
            .map(|dims| dims.iter().map(|&(w, log_h)| w << log_h).sum::<usize>())
            .sum::<usize>();
        let total_number_of_variables = total_data_length.next_power_of_two().ilog2();
        // Sanity check: the test data must fill a power-of-two area exactly.
        assert_eq!(1 << total_number_of_variables, total_data_length);
        // Per-round areas, rounded up to a multiple of the stacking height
        // (what the verifier expects for each commitment).
        let round_areas = round_widths_and_log_heights
            .iter()
            .map(|dims| {
                dims.iter()
                    .map(|&(w, log_h)| w << log_h)
                    .sum::<usize>()
                    .next_multiple_of(1 << log_stacking_height)
            })
            .collect::<Vec<_>>();

        let mut rng = thread_rng();
        // Random MLE data matching the declared dimensions.
        let round_mles = round_widths_and_log_heights
            .iter()
            .map(|dims| {
                dims.iter()
                    .map(|&(w, log_h)| Mle::<BabyBear>::rand(&mut rng, w, log_h))
                    .collect::<Message<_>>()
            })
            .collect::<Rounds<_>>();

        let pcs_verifier = BasefoldVerifier::<GC>::new(
            FriConfig::default_fri_config(),
            round_widths_and_log_heights.len(),
        );
        let pcs_prover = Prover::new(&pcs_verifier);

        let verifier = StackedPcsVerifier::new(pcs_verifier, log_stacking_height);
        let prover = StackedPcsProver::new(pcs_prover, log_stacking_height, batch_size);

        let mut challenger = GC::default_challenger();
        let mut commitments = vec![];
        let mut prover_data = Rounds::new();
        let mut batch_evaluations = Rounds::new();
        let point = Point::<EF>::rand(&mut rng, total_number_of_variables);

        // Reference value: concatenate all MLE guts (transposed to row-major
        // order) into one big MLE and evaluate it directly at `point`.
        let concat_mle: Vec<BabyBear> = round_mles
            .iter()
            .flat_map(|mles| mles.iter())
            .flat_map(|mle| mle.guts().transpose().as_slice().to_vec())
            .collect();

        let concat_mle =
            Mle::new(Tensor::from(concat_mle).reshape([1 << total_number_of_variables, 1]));

        let concat_eval_claim = concat_mle.eval_at(&point)[0];

        // Split `point` into the batch coordinates (high) and the stacked-row
        // coordinates (low `log_stacking_height` variables).
        let (batch_point, stack_point) =
            point.split_at(point.dimension() - log_stacking_height as usize);
        for mles in round_mles.iter() {
            let (commitment, data, _) = prover.commit_multilinears(mles.clone()).unwrap();
            challenger.observe(commitment);
            commitments.push(commitment);
            let evaluations = prover.round_batch_evaluations(&stack_point, &data);
            prover_data.push(data);
            batch_evaluations.push(evaluations);
        }

        // The batch evaluations, viewed as an MLE over the batch variables,
        // must reproduce the direct evaluation of the concatenated data.
        let batch_evaluations_mle =
            batch_evaluations.iter().flatten().flatten().cloned().collect::<Mle<_>>();
        let eval_claim = batch_evaluations_mle.eval_at(&batch_point)[0];

        assert_eq!(concat_eval_claim, eval_claim);

        let proof = prover
            .prove_trusted_evaluation(point.clone(), eval_claim, prover_data, &mut challenger)
            .unwrap();

        // Re-derive the verifier's challenger from the same observations so
        // the Fiat-Shamir transcripts match.
        let mut challenger = GC::default_challenger();
        for commitment in commitments.iter() {
            challenger.observe(*commitment);
        }
        verifier
            .verify_trusted_evaluation(
                &commitments,
                &round_areas,
                &point,
                &proof,
                eval_claim,
                &mut challenger,
            )
            .unwrap();
    }
}