// sp1_recursion_program/fri/two_adic_pcs.rs

1use p3_commit::TwoAdicMultiplicativeCoset;
2use p3_field::{AbstractField, TwoAdicField};
3use p3_symmetric::Hash;
4use sp1_primitives::types::RecursionProgramType;
5use sp1_recursion_compiler::prelude::*;
6use sp1_recursion_core::runtime::DIGEST_SIZE;
7
8use super::{
9    types::{
10        DigestVariable, DimensionsVariable, FriConfigVariable, TwoAdicPcsMatsVariable,
11        TwoAdicPcsProofVariable, TwoAdicPcsRoundVariable,
12    },
13    verify_batch, verify_challenges, verify_shape_and_sample_challenges,
14    TwoAdicMultiplicativeCosetVariable,
15};
16use crate::{
17    challenger::{DuplexChallengerVariable, FeltChallenger},
18    commit::PcsVariable,
19};
20
21pub fn verify_two_adic_pcs<C: Config>(
22    builder: &mut Builder<C>,
23    config: &FriConfigVariable<C>,
24    rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
25    proof: TwoAdicPcsProofVariable<C>,
26    challenger: &mut DuplexChallengerVariable<C>,
27) where
28    C::F: TwoAdicField,
29    C::EF: TwoAdicField,
30{
31    let mut input_ptr = builder.array::<FriFoldInput<_>>(1);
32    let g = builder.generator();
33
34    let log_blowup = config.log_blowup;
35    let blowup = config.blowup;
36    let alpha = challenger.sample_ext(builder);
37
38    builder.cycle_tracker("stage-d-1-verify-shape-and-sample-challenges");
39    let fri_challenges =
40        verify_shape_and_sample_challenges(builder, config, &proof.fri_proof, challenger);
41    builder.cycle_tracker("stage-d-1-verify-shape-and-sample-challenges");
42
43    let commit_phase_commits_len = proof.fri_proof.commit_phase_commits.len().materialize(builder);
44    let log_global_max_height: Var<_> = builder.eval(commit_phase_commits_len + log_blowup);
45
46    let mut reduced_openings: Array<C, Array<C, Ext<C::F, C::EF>>> =
47        builder.array(proof.query_openings.len());
48
49    builder.cycle_tracker("stage-d-2-fri-fold");
50    builder.range(0, proof.query_openings.len()).for_each(|i, builder| {
51        let query_opening = builder.get(&proof.query_openings, i);
52        let index_bits = builder.get(&fri_challenges.query_indices, i);
53
54        let mut ro: Array<C, Ext<C::F, C::EF>> = builder.array(32);
55        let mut alpha_pow: Array<C, Ext<C::F, C::EF>> = builder.array(32);
56        let zero_ef = builder.eval(C::EF::zero().cons());
57        for j in 0..32 {
58            builder.set_value(&mut ro, j, zero_ef);
59        }
60        let one_ef = builder.eval(C::EF::one().cons());
61        for j in 0..32 {
62            builder.set_value(&mut alpha_pow, j, one_ef);
63        }
64
65        builder.range(0, rounds.len()).for_each(|j, builder| {
66            let batch_opening = builder.get(&query_opening, j);
67            let round = builder.get(&rounds, j);
68            let batch_commit = round.batch_commit;
69            let mats = round.mats;
70
71            let mut batch_heights_log2: Array<C, Var<C::N>> = builder.array(mats.len());
72            builder.range(0, mats.len()).for_each(|k, builder| {
73                let mat = builder.get(&mats, k);
74                let height_log2: Var<_> = builder.eval(mat.domain.log_n + log_blowup);
75                builder.set_value(&mut batch_heights_log2, k, height_log2);
76            });
77            let mut batch_dims: Array<C, DimensionsVariable<C>> = builder.array(mats.len());
78            builder.range(0, mats.len()).for_each(|k, builder| {
79                let mat = builder.get(&mats, k);
80                let dim =
81                    DimensionsVariable::<C> { height: builder.eval(mat.domain.size() * blowup) };
82                builder.set_value(&mut batch_dims, k, dim);
83            });
84
85            let log_batch_max_height = builder.get(&batch_heights_log2, 0);
86            let bits_reduced: Var<_> = builder.eval(log_global_max_height - log_batch_max_height);
87            let index_bits_shifted_v1 = index_bits.shift(builder, bits_reduced);
88            verify_batch::<C, 1>(
89                builder,
90                &batch_commit,
91                batch_dims,
92                index_bits_shifted_v1,
93                batch_opening.opened_values.clone(),
94                &batch_opening.opening_proof,
95            );
96
97            builder.range(0, batch_opening.opened_values.len()).for_each(|k, builder| {
98                let mat_opening = builder.get(&batch_opening.opened_values, k);
99                let mat = builder.get(&mats, k);
100                let mat_points = mat.points;
101                let mat_values = mat.values;
102
103                let log2_domain_size = mat.domain.log_n;
104                let log_height: Var<C::N> = builder.eval(log2_domain_size + log_blowup);
105
106                let bits_reduced: Var<C::N> = builder.eval(log_global_max_height - log_height);
107                let index_bits_shifted = index_bits.shift(builder, bits_reduced);
108
109                let two_adic_generator = config.get_two_adic_generator(builder, log_height);
110                builder.cycle_tracker("exp_reverse_bits_len");
111
112                let two_adic_generator_exp: Felt<C::F> =
113                    if matches!(builder.program_type, RecursionProgramType::Wrap) {
114                        builder.exp_reverse_bits_len(
115                            two_adic_generator,
116                            &index_bits_shifted,
117                            log_height,
118                        )
119                    } else {
120                        builder.exp_reverse_bits_len_fast(
121                            two_adic_generator,
122                            &index_bits_shifted,
123                            log_height,
124                        )
125                    };
126
127                builder.cycle_tracker("exp_reverse_bits_len");
128                let x: Felt<C::F> = builder.eval(two_adic_generator_exp * g);
129
130                builder.range(0, mat_points.len()).for_each(|l, builder| {
131                    let z: Ext<C::F, C::EF> = builder.get(&mat_points, l);
132                    let ps_at_z = builder.get(&mat_values, l);
133                    let input = FriFoldInput {
134                        z,
135                        alpha,
136                        x,
137                        log_height,
138                        mat_opening: mat_opening.clone(),
139                        ps_at_z: ps_at_z.clone(),
140                        alpha_pow: alpha_pow.clone(),
141                        ro: ro.clone(),
142                    };
143                    builder.set_value(&mut input_ptr, 0, input);
144
145                    let ps_at_z_len = ps_at_z.len().materialize(builder);
146                    builder.push(DslIr::FriFold(ps_at_z_len, input_ptr.clone()));
147                });
148            });
149        });
150
151        builder.set_value(&mut reduced_openings, i, ro);
152    });
153    builder.cycle_tracker("stage-d-2-fri-fold");
154
155    builder.cycle_tracker("stage-d-3-verify-challenges");
156    verify_challenges(builder, config, &proof.fri_proof, &fri_challenges, &reduced_openings);
157    builder.cycle_tracker("stage-d-3-verify-challenges");
158}
159
160impl<C: Config> FromConstant<C> for TwoAdicPcsRoundVariable<C>
161where
162    C::F: TwoAdicField,
163{
164    type Constant = (
165        Hash<C::F, C::F, DIGEST_SIZE>,
166        Vec<(TwoAdicMultiplicativeCoset<C::F>, Vec<(C::EF, Vec<C::EF>)>)>,
167    );
168
169    fn constant(value: Self::Constant, builder: &mut Builder<C>) -> Self {
170        let (commit_val, domains_and_openings_val) = value;
171
172        // Allocate the commitment.
173        let mut commit = builder.dyn_array::<Felt<_>>(DIGEST_SIZE);
174        let commit_val: [C::F; DIGEST_SIZE] = commit_val.into();
175        for (i, f) in commit_val.into_iter().enumerate() {
176            builder.set(&mut commit, i, f);
177        }
178
179        let mut mats =
180            builder.dyn_array::<TwoAdicPcsMatsVariable<C>>(domains_and_openings_val.len());
181
182        for (i, (domain, openning)) in domains_and_openings_val.into_iter().enumerate() {
183            let domain = builder.constant::<TwoAdicMultiplicativeCosetVariable<_>>(domain);
184
185            let points_val = openning.iter().map(|(p, _)| *p).collect::<Vec<_>>();
186            let values_val = openning.iter().map(|(_, v)| v.clone()).collect::<Vec<_>>();
187            let mut points: Array<_, Ext<_, _>> = builder.dyn_array(points_val.len());
188            for (j, point) in points_val.into_iter().enumerate() {
189                let el: Ext<_, _> = builder.eval(point.cons());
190                builder.set_value(&mut points, j, el);
191            }
192            let mut values: Array<_, Array<_, Ext<_, _>>> = builder.dyn_array(values_val.len());
193            for (j, val) in values_val.into_iter().enumerate() {
194                let mut tmp = builder.dyn_array(val.len());
195                for (k, v) in val.into_iter().enumerate() {
196                    let el: Ext<_, _> = builder.eval(v.cons());
197                    builder.set_value(&mut tmp, k, el);
198                }
199                builder.set_value(&mut values, j, tmp);
200            }
201
202            let mat = TwoAdicPcsMatsVariable { domain, points, values };
203            builder.set_value(&mut mats, i, mat);
204        }
205
206        Self { batch_commit: commit, mats }
207    }
208}
209
210#[derive(DslVariable, Clone)]
211pub struct TwoAdicFriPcsVariable<C: Config> {
212    pub config: FriConfigVariable<C>,
213}
214
215impl<C: Config> PcsVariable<C, DuplexChallengerVariable<C>> for TwoAdicFriPcsVariable<C>
216where
217    C::F: TwoAdicField,
218    C::EF: TwoAdicField,
219{
220    type Domain = TwoAdicMultiplicativeCosetVariable<C>;
221
222    type Commitment = DigestVariable<C>;
223
224    type Proof = TwoAdicPcsProofVariable<C>;
225
226    fn natural_domain_for_log_degree(
227        &self,
228        builder: &mut Builder<C>,
229        log_degree: Usize<C::N>,
230    ) -> Self::Domain {
231        self.config.get_subgroup(builder, log_degree)
232    }
233
234    fn verify(
235        &self,
236        builder: &mut Builder<C>,
237        rounds: Array<C, TwoAdicPcsRoundVariable<C>>,
238        proof: Self::Proof,
239        challenger: &mut DuplexChallengerVariable<C>,
240    ) {
241        verify_two_adic_pcs(builder, &self.config, rounds, proof, challenger)
242    }
243}
244
245pub mod tests {
246
247    use std::{cmp::Reverse, collections::VecDeque};
248
249    use crate::{
250        challenger::{CanObserveVariable, DuplexChallengerVariable, FeltChallenger},
251        commit::PcsVariable,
252        fri::{
253            types::TwoAdicPcsRoundVariable, TwoAdicFriPcsVariable,
254            TwoAdicMultiplicativeCosetVariable,
255        },
256        hints::Hintable,
257        utils::const_fri_config,
258    };
259    use itertools::Itertools;
260    use p3_baby_bear::BabyBear;
261    use p3_challenger::{CanObserve, FieldChallenger};
262    use p3_commit::{Pcs, TwoAdicMultiplicativeCoset};
263    use p3_field::AbstractField;
264    use p3_matrix::dense::RowMajorMatrix;
265    use rand::rngs::OsRng;
266
267    use sp1_recursion_compiler::{
268        config::InnerConfig,
269        ir::{Array, Builder, Usize, Var},
270    };
271    use sp1_recursion_core::{
272        air::Block,
273        runtime::{RecursionProgram, DIGEST_SIZE},
274    };
275    use sp1_stark::{
276        baby_bear_poseidon2::compressed_fri_config, inner_perm, InnerChallenge, InnerChallenger,
277        InnerCompress, InnerDft, InnerHash, InnerPcs, InnerPcsProof, InnerVal, InnerValMmcs,
278    };
279
280    pub fn build_test_fri_with_cols_and_log2_rows(
281        nb_cols: usize,
282        nb_log2_rows: usize,
283    ) -> (RecursionProgram<BabyBear>, VecDeque<Vec<Block<BabyBear>>>) {
284        let mut rng = &mut OsRng;
285        let log_degrees = &[nb_log2_rows];
286        let perm = inner_perm();
287        let fri_config = compressed_fri_config();
288        let hash = InnerHash::new(perm.clone());
289        let compress = InnerCompress::new(perm.clone());
290        let val_mmcs = InnerValMmcs::new(hash, compress);
291        let dft = InnerDft {};
292        let pcs_val: InnerPcs =
293            InnerPcs::new(log_degrees.iter().copied().max().unwrap(), dft, val_mmcs, fri_config);
294
295        // Generate proof.
296        let domains_and_polys = log_degrees
297            .iter()
298            .map(|&d| {
299                (
300                    <InnerPcs as Pcs<InnerChallenge, InnerChallenger>>::natural_domain_for_degree(
301                        &pcs_val,
302                        1 << d,
303                    ),
304                    RowMajorMatrix::<InnerVal>::rand(&mut rng, 1 << d, nb_cols),
305                )
306            })
307            .sorted_by_key(|(dom, _)| Reverse(dom.log_n))
308            .collect::<Vec<_>>();
309        let (commit, data) = <InnerPcs as Pcs<InnerChallenge, InnerChallenger>>::commit(
310            &pcs_val,
311            domains_and_polys.clone(),
312        );
313        let mut challenger = InnerChallenger::new(perm.clone());
314        challenger.observe(commit);
315        let zeta = challenger.sample_ext_element::<InnerChallenge>();
316        let points = domains_and_polys.iter().map(|_| vec![zeta]).collect::<Vec<_>>();
317        let (opening, proof) = pcs_val.open(vec![(&data, points)], &mut challenger);
318
319        // Verify proof.
320        let mut challenger = InnerChallenger::new(perm.clone());
321        challenger.observe(commit);
322        challenger.sample_ext_element::<InnerChallenge>();
323        let os: Vec<(
324            TwoAdicMultiplicativeCoset<InnerVal>,
325            Vec<(InnerChallenge, Vec<InnerChallenge>)>,
326        )> = domains_and_polys
327            .iter()
328            .zip(&opening[0])
329            .map(|((domain, _), mat_openings)| (*domain, vec![(zeta, mat_openings[0].clone())]))
330            .collect();
331        pcs_val.verify(vec![(commit, os.clone())], &proof, &mut challenger).unwrap();
332
333        // Test the recursive Pcs.
334        let mut builder = Builder::<InnerConfig>::default();
335        let config = const_fri_config(&mut builder, &compressed_fri_config());
336        let pcs = TwoAdicFriPcsVariable { config };
337        let rounds =
338            builder.constant::<Array<_, TwoAdicPcsRoundVariable<_>>>(vec![(commit, os.clone())]);
339
340        // Test natural domain for degree.
341        for log_d_val in log_degrees.iter() {
342            let log_d: Var<_> = builder.eval(InnerVal::from_canonical_usize(*log_d_val));
343            let domain = pcs.natural_domain_for_log_degree(&mut builder, Usize::Var(log_d));
344
345            let domain_val =
346                <InnerPcs as Pcs<InnerChallenge, InnerChallenger>>::natural_domain_for_degree(
347                    &pcs_val,
348                    1 << log_d_val,
349                );
350
351            let expected_domain: TwoAdicMultiplicativeCosetVariable<_> =
352                builder.constant(domain_val);
353
354            builder.assert_eq::<TwoAdicMultiplicativeCosetVariable<_>>(domain, expected_domain);
355        }
356
357        // Test proof verification.
358        let proofvar = InnerPcsProof::read(&mut builder);
359        let mut challenger = DuplexChallengerVariable::new(&mut builder);
360        let commit = <[InnerVal; DIGEST_SIZE]>::from(commit).to_vec();
361        let commit = builder.constant::<Array<_, _>>(commit);
362        challenger.observe(&mut builder, commit);
363        challenger.sample_ext(&mut builder);
364        pcs.verify(&mut builder, rounds, proofvar, &mut challenger);
365        builder.halt();
366
367        let program = builder.compile_program();
368        let mut witness_stream = VecDeque::new();
369        witness_stream.extend(proof.write());
370        (program, witness_stream)
371    }
372
373    #[test]
374    fn test_two_adic_fri_pcs_single_batch() {
375        use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig};
376        let (program, witness) = build_test_fri_with_cols_and_log2_rows(10, 16);
377
378        // We don't test with the config TestConfig::WideDeg17Wrap, since it doesn't have the
379        // `ExpReverseBitsLen` chip.
380        run_test_recursion(program.clone(), Some(witness.clone()), TestConfig::WideDeg3);
381        run_test_recursion(program, Some(witness), TestConfig::SkinnyDeg7);
382    }
383}