sp1_recursion_program/
constraints.rs

1use p3_air::Air;
2use p3_commit::LagrangeSelectors;
3use p3_field::{AbstractExtensionField, AbstractField, TwoAdicField};
4
5use sp1_recursion_compiler::{
6    ir::{Array, Felt},
7    prelude::{Builder, Config, Ext, ExtConst, SymbolicExt},
8};
9use sp1_stark::{
10    air::MachineAir, AirOpenedValues, MachineChip, StarkGenericConfig, PROOF_MAX_NUM_PVS,
11};
12
13use crate::{
14    commit::PolynomialSpaceVariable,
15    fri::TwoAdicMultiplicativeCosetVariable,
16    stark::{RecursiveVerifierConstraintFolder, StarkVerifier},
17    types::{ChipOpenedValuesVariable, ChipOpening},
18};
19
20impl<C: Config, SC: StarkGenericConfig> StarkVerifier<C, SC>
21where
22    SC: StarkGenericConfig<Val = C::F, Challenge = C::EF>,
23    C::F: TwoAdicField,
24{
25    fn eval_constrains<A>(
26        builder: &mut Builder<C>,
27        chip: &MachineChip<SC, A>,
28        opening: &ChipOpening<C>,
29        public_values: Array<C, Felt<C::F>>,
30        selectors: &LagrangeSelectors<Ext<C::F, C::EF>>,
31        alpha: Ext<C::F, C::EF>,
32        permutation_challenges: &[Ext<C::F, C::EF>],
33    ) -> Ext<C::F, C::EF>
34    where
35        A: for<'b> Air<RecursiveVerifierConstraintFolder<'b, C>>,
36    {
37        let mut unflatten = |v: &[Ext<C::F, C::EF>]| {
38            v.chunks_exact(SC::Challenge::D)
39                .map(|chunk| {
40                    builder.eval(
41                        chunk
42                            .iter()
43                            .enumerate()
44                            .map(|(e_i, &x)| x * C::EF::monomial(e_i).cons())
45                            .sum::<SymbolicExt<_, _>>(),
46                    )
47                })
48                .collect::<Vec<Ext<_, _>>>()
49        };
50        let perm_opening = AirOpenedValues {
51            local: unflatten(&opening.permutation.local),
52            next: unflatten(&opening.permutation.next),
53        };
54
55        let mut folder_pv = Vec::new();
56        for i in 0..PROOF_MAX_NUM_PVS {
57            folder_pv.push(builder.get(&public_values, i));
58        }
59
60        let mut folder = RecursiveVerifierConstraintFolder::<C> {
61            preprocessed: opening.preprocessed.view(),
62            main: opening.main.view(),
63            perm: perm_opening.view(),
64            perm_challenges: permutation_challenges,
65            cumulative_sum: opening.cumulative_sum,
66            public_values: &folder_pv,
67            is_first_row: selectors.is_first_row,
68            is_last_row: selectors.is_last_row,
69            is_transition: selectors.is_transition,
70            alpha,
71            accumulator: SymbolicExt::zero(),
72            _marker: std::marker::PhantomData,
73        };
74
75        chip.eval(&mut folder);
76        builder.eval(folder.accumulator)
77    }
78
79    fn recompute_quotient(
80        builder: &mut Builder<C>,
81        opening: &ChipOpening<C>,
82        qc_domains: Vec<TwoAdicMultiplicativeCosetVariable<C>>,
83        zeta: Ext<C::F, C::EF>,
84    ) -> Ext<C::F, C::EF> {
85        let zps = qc_domains
86            .iter()
87            .enumerate()
88            .map(|(i, domain)| {
89                qc_domains
90                    .iter()
91                    .enumerate()
92                    .filter(|(j, _)| *j != i)
93                    .map(|(_, other_domain)| {
94                        let first_point: Ext<_, _> = builder.eval(domain.first_point());
95                        other_domain.zp_at_point(builder, zeta)
96                            * other_domain.zp_at_point(builder, first_point).inverse()
97                    })
98                    .product::<SymbolicExt<_, _>>()
99            })
100            .collect::<Vec<SymbolicExt<_, _>>>()
101            .into_iter()
102            .map(|x| builder.eval(x))
103            .collect::<Vec<Ext<_, _>>>();
104
105        builder.eval(
106            opening
107                .quotient
108                .iter()
109                .enumerate()
110                .map(|(ch_i, ch)| {
111                    assert_eq!(ch.len(), C::EF::D);
112                    ch.iter()
113                        .enumerate()
114                        .map(|(e_i, &c)| zps[ch_i] * C::EF::monomial(e_i) * c)
115                        .sum::<SymbolicExt<_, _>>()
116                })
117                .sum::<SymbolicExt<_, _>>(),
118        )
119    }
120
121    /// Reference: [sp1_core_machine::stark::Verifier::verify_constraints]
122    pub fn verify_constraints<A>(
123        builder: &mut Builder<C>,
124        chip: &MachineChip<SC, A>,
125        opening: &ChipOpenedValuesVariable<C>,
126        public_values: Array<C, Felt<C::F>>,
127        trace_domain: TwoAdicMultiplicativeCosetVariable<C>,
128        qc_domains: Vec<TwoAdicMultiplicativeCosetVariable<C>>,
129        zeta: Ext<C::F, C::EF>,
130        alpha: Ext<C::F, C::EF>,
131        permutation_challenges: &[Ext<C::F, C::EF>],
132    ) where
133        A: MachineAir<C::F> + for<'a> Air<RecursiveVerifierConstraintFolder<'a, C>>,
134    {
135        let opening = ChipOpening::from_variable(builder, chip, opening);
136        let sels = trace_domain.selectors_at_point(builder, zeta);
137
138        let folded_constraints = Self::eval_constrains(
139            builder,
140            chip,
141            &opening,
142            public_values,
143            &sels,
144            alpha,
145            permutation_challenges,
146        );
147
148        let quotient: Ext<_, _> = Self::recompute_quotient(builder, &opening, qc_domains, zeta);
149
150        // Assert that the quotient times the zerofier is equal to the folded constraints.
151        builder.assert_ext_eq(folded_constraints * sels.inv_zeroifier, quotient);
152    }
153}
154
#[cfg(test)]
mod tests {
    use itertools::{izip, Itertools};
    use rand::{thread_rng, Rng};

    use sp1_core_executor::Program;
    use sp1_core_machine::{io::SP1Stdin, riscv::RiscvAir};
    use sp1_recursion_core::stark::utils::{run_test_recursion, TestConfig};

    use p3_challenger::{CanObserve, FieldChallenger};
    use sp1_recursion_compiler::{asm::AsmBuilder, ir::Felt, prelude::ExtConst};

    use p3_commit::{Pcs, PolynomialSpace};
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, Chip, Com, CpuProver, Dom, OpeningProof,
        PcsProverData, SP1CoreOpts, ShardCommitment, ShardProof, StarkGenericConfig, StarkMachine,
    };

    use crate::stark::StarkVerifier;

    /// Replays one shard's part of the transcript and derives the data needed to check its
    /// constraints.
    ///
    /// Mutates `challenger` in this order: samples two permutation challenges, observes the
    /// permutation commitment, samples `alpha`, observes the quotient commitment, samples
    /// `zeta`. Returns the shard's chips (in `proof.chip_ordering` order), each chip's trace
    /// domain, each chip's quotient-chunk domains, the permutation challenges, `alpha` and
    /// `zeta`.
    // The tuple return type mirrors the verifier's intermediate values one-to-one.
    #[allow(clippy::type_complexity)]
    fn get_shard_data<'a, SC>(
        machine: &'a StarkMachine<SC, RiscvAir<SC::Val>>,
        proof: &'a ShardProof<SC>,
        challenger: &mut SC::Challenger,
    ) -> (
        Vec<&'a Chip<SC::Val, RiscvAir<SC::Val>>>,
        Vec<Dom<SC>>,
        Vec<Vec<Dom<SC>>>,
        Vec<SC::Challenge>,
        SC::Challenge,
        SC::Challenge,
    )
    where
        SC: StarkGenericConfig + Default,
        SC::Challenger: Clone,
        OpeningProof<SC>: Send + Sync,
        Com<SC>: Send + Sync,
        PcsProverData<SC>: Send + Sync,
        SC::Val: p3_field::PrimeField32,
    {
        let ShardProof { commitment, opened_values, .. } = proof;

        let ShardCommitment { permutation_commit, quotient_commit, .. } = commitment;

        // Extract verification metadata.
        let pcs = machine.config().pcs();

        let permutation_challenges =
            (0..2).map(|_| challenger.sample_ext_element::<SC::Challenge>()).collect::<Vec<_>>();

        challenger.observe(permutation_commit.clone());

        let alpha = challenger.sample_ext_element::<SC::Challenge>();

        // Observe the quotient commitments.
        challenger.observe(quotient_commit.clone());

        let zeta = challenger.sample_ext_element::<SC::Challenge>();

        let chips = machine.shard_chips_ordered(&proof.chip_ordering).collect::<Vec<_>>();

        let log_degrees = opened_values.chips.iter().map(|val| val.log_degree).collect::<Vec<_>>();

        let log_quotient_degrees =
            chips.iter().map(|chip| chip.log_quotient_degree()).collect::<Vec<_>>();

        let trace_domains = log_degrees
            .iter()
            .map(|log_degree| pcs.natural_domain_for_degree(1 << log_degree))
            .collect::<Vec<_>>();

        // Each chip's quotient domain is a disjoint coset of size `2^(log_degree +
        // log_quotient_degree)`, split into `2^log_quotient_degree` chunk domains.
        let quotient_chunk_domains = trace_domains
            .iter()
            .zip_eq(log_degrees)
            .zip_eq(log_quotient_degrees)
            .map(|((domain, log_degree), log_quotient_degree)| {
                let quotient_degree = 1 << log_quotient_degree;
                let quotient_domain =
                    domain.create_disjoint_domain(1 << (log_degree + log_quotient_degree));
                quotient_domain.split_domains(quotient_degree)
            })
            .collect::<Vec<_>>();

        (chips, trace_domains, quotient_chunk_domains, permutation_challenges, alpha, zeta)
    }

    /// Proves the Fibonacci program natively, then re-checks the first shard's constraints for
    /// every chip inside a compiled recursion program.
    #[test]
    fn test_verify_constraints() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;
        type A = RiscvAir<F>;

        // Generate a dummy proof.
        sp1_core_machine::utils::setup_logger();
        let elf = include_bytes!("../../../../tests/fibonacci/elf/riscv32im-succinct-zkvm-elf");

        let machine = A::machine(SC::default());
        let (_, vk) = machine.setup(&Program::from(elf).unwrap());
        let mut challenger = machine.config().challenger();
        let (proof, _, _) = sp1_core_machine::utils::prove::<_, CpuProver<_, _>>(
            Program::from(elf).unwrap(),
            &SP1Stdin::new(),
            SC::default(),
            SP1CoreOpts::default(),
        )
        .unwrap();
        machine.verify(&vk, &proof, &mut challenger).unwrap();

        println!("Proof generated and verified successfully");

        // Fresh challenger: observe the verifying key, then every shard's main commitment and
        // public values, before deriving the first shard's challenges below.
        let mut challenger = machine.config().challenger();
        vk.observe_into(&mut challenger);
        proof.shard_proofs.iter().for_each(|proof| {
            challenger.observe(proof.commitment.main_commit);
            challenger.observe_slice(&proof.public_values[0..machine.num_pv_elts()]);
        });

        // Run the verify inside the DSL and compare it to the calculated value.
        let mut builder = AsmBuilder::<F, EF>::default();

        // Only the first shard is checked.
        if let Some(proof) = proof.shard_proofs.into_iter().next() {
            let (
                chips,
                trace_domains_vals,
                quotient_chunk_domains_vals,
                permutation_challenges_vals,
                alpha_val,
                zeta_val,
            ) = get_shard_data(&machine, &proof, &mut challenger);

            // The challenges are shared by every chip of the shard, so load them once.
            let alpha = builder.eval(alpha_val.cons());
            let zeta = builder.eval(zeta_val.cons());
            let permutation_challenges = permutation_challenges_vals
                .iter()
                .map(|c| builder.eval(c.cons()))
                .collect::<Vec<_>>();

            for (chip, trace_domain_val, qc_domains_vals, values_vals) in izip!(
                chips.iter(),
                trace_domains_vals,
                quotient_chunk_domains_vals,
                proof.opened_values.chips.iter(),
            ) {
                let opening = builder.constant(values_vals.clone());
                let trace_domain = builder.constant(trace_domain_val);
                // `verify_constraints` takes the array by value, so each chip gets its own copy.
                let public_values = builder.constant(proof.public_values.clone());

                let qc_domains = qc_domains_vals
                    .iter()
                    .map(|domain| builder.constant(*domain))
                    .collect::<Vec<_>>();

                StarkVerifier::<_, SC>::verify_constraints::<A>(
                    &mut builder,
                    chip,
                    &opening,
                    public_values,
                    trace_domain,
                    qc_domains,
                    zeta,
                    alpha,
                    &permutation_challenges,
                )
            }
        }
        builder.halt();

        let program = builder.compile_program();
        run_test_recursion(program, None, TestConfig::All);
    }

    /// Checks that `exp_reverse_bits_len_fast` agrees with `exp_reverse_bits_len` on a random
    /// base-field element for a bit length of 5.
    #[test]
    fn test_exp_reverse_bit_len_fast() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let mut rng = thread_rng();

        // Initialize a builder.
        let mut builder = AsmBuilder::<F, EF>::default();

        // Sample a random base-field element.
        let x_val: F = rng.gen();

        // Materialize the value as a `Felt` and decompose it into bits.
        let x_felt: Felt<_> = builder.eval(x_val);
        let x_bits = builder.num2bits_f(x_felt);

        let result = builder.exp_reverse_bits_len_fast(x_felt, &x_bits, 5);
        let expected_val = builder.exp_reverse_bits_len(x_felt, &x_bits, 5);

        builder.assert_felt_eq(expected_val, result);
        builder.halt();

        let program = builder.compile_program();

        // We don't test with the config TestConfig::WideDeg17Wrap, since it doesn't have the
        // `ExpReverseBitsLen` chip.
        run_test_recursion(program.clone(), None, TestConfig::WideDeg3);
        run_test_recursion(program, None, TestConfig::SkinnyDeg7);
    }

    /// Compiles and runs a minimal program that writes one random `Felt` and halts.
    #[test]
    fn test_memory_finalize() {
        type SC = BabyBearPoseidon2;
        type F = <SC as StarkGenericConfig>::Val;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let mut rng = thread_rng();

        // Initialize a builder.
        let mut builder = AsmBuilder::<F, EF>::default();

        // Sample a random base-field element.
        let x_val: F = rng.gen();

        // Materialize the value as a `Felt`; it is never read back.
        let _x_felt: Felt<_> = builder.eval(x_val);

        builder.halt();

        let program = builder.compile_program();

        run_test_recursion(program, None, TestConfig::All);
    }
}