sp1_recursion_core_v2/machine.rs

1use p3_field::{extension::BinomiallyExtendable, PrimeField32};
2use sp1_recursion_core::runtime::D;
3use sp1_stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS};
4
5use crate::chips::{
6    alu_base::BaseAluChip,
7    alu_ext::ExtAluChip,
8    dummy::DummyChip,
9    exp_reverse_bits::ExpReverseBitsLenChip,
10    fri_fold::FriFoldChip,
11    mem::{MemoryConstChip, MemoryVarChip},
12    poseidon2_skinny::Poseidon2SkinnyChip,
13    poseidon2_wide::Poseidon2WideChip,
14    public_values::PublicValuesChip,
15};
16
17#[derive(sp1_derive::MachineAir)]
18#[sp1_core_path = "sp1_core_machine"]
19#[execution_record_path = "crate::ExecutionRecord<F>"]
20#[program_path = "crate::RecursionProgram<F>"]
21#[builder_path = "crate::builder::SP1RecursionAirBuilder<F = F>"]
22#[eval_trait_bound = "AB::Var: 'static"]
23pub enum RecursionAir<
24    F: PrimeField32 + BinomiallyExtendable<D>,
25    const DEGREE: usize,
26    const COL_PADDING: usize,
27> {
28    // Program(ProgramChip<F>),
29    MemoryConst(MemoryConstChip<F>),
30    MemoryVar(MemoryVarChip<F>),
31    BaseAlu(BaseAluChip),
32    ExtAlu(ExtAluChip),
33    // Cpu(CpuChip<F, DEGREE>),
34    // MemoryGlobal(MemoryGlobalChip),
35    Poseidon2Skinny(Poseidon2SkinnyChip<DEGREE>),
36    Poseidon2Wide(Poseidon2WideChip<DEGREE>),
37    FriFold(FriFoldChip<DEGREE>),
38    // RangeCheck(RangeCheckChip<F>),
39    // Multi(MultiChip<DEGREE>),
40    ExpReverseBitsLen(ExpReverseBitsLenChip<DEGREE>),
41    PublicValues(PublicValuesChip),
42    DummyWide(DummyChip<COL_PADDING>),
43}
44
45impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize, const COL_PADDING: usize>
46    RecursionAir<F, DEGREE, COL_PADDING>
47{
48    /// A recursion machine that can have dynamic trace sizes.
49    pub fn machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
50        let chips = Self::get_all().into_iter().map(Chip::new).collect::<Vec<_>>();
51        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
52    }
53
54    /// A recursion machine that can have dynamic trace sizes, and uses the wide variant of
55    /// Poseidon2.
56    pub fn machine_wide<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
57        let chips = Self::get_all_wide().into_iter().map(Chip::new).collect::<Vec<_>>();
58        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
59    }
60
61    pub fn machine_with_padding<SC: StarkGenericConfig<Val = F>>(
62        config: SC,
63        fri_fold_padding: usize,
64        poseidon2_padding: usize,
65        erbl_padding: usize,
66    ) -> StarkMachine<SC, Self> {
67        let chips = Self::get_all_with_padding(fri_fold_padding, poseidon2_padding, erbl_padding)
68            .into_iter()
69            .map(Chip::new)
70            .collect::<Vec<_>>();
71        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
72    }
73
74    pub fn dummy_machine<SC: StarkGenericConfig<Val = F>>(
75        config: SC,
76        log_height: usize,
77    ) -> StarkMachine<SC, Self> {
78        let chips = vec![RecursionAir::DummyWide(DummyChip::new(log_height))];
79        StarkMachine::new(config, chips.into_iter().map(Chip::new).collect(), PROOF_MAX_NUM_PVS)
80    }
81    // /// A recursion machine with fixed trace sizes tuned to work specifically for the wrap layer.
82    // pub fn wrap_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
83    //     let chips = Self::get_wrap_all()
84    //         .into_iter()
85    //         .map(Chip::new)
86    //         .collect::<Vec<_>>();
87    //     StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
88    // }
89
90    // /// A recursion machine with fixed trace sizes tuned to work specifically for the wrap layer.
91    // pub fn wrap_machine_dyn<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC,
92    // Self> {     let chips = Self::get_wrap_dyn_all()
93    //         .into_iter()
94    //         .map(Chip::new)
95    //         .collect::<Vec<_>>();
96    //     StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
97    // }
98
99    pub fn get_all() -> Vec<Self> {
100        vec![
101            RecursionAir::MemoryConst(MemoryConstChip::default()),
102            RecursionAir::MemoryVar(MemoryVarChip::default()),
103            RecursionAir::BaseAlu(BaseAluChip::default()),
104            RecursionAir::ExtAlu(ExtAluChip::default()),
105            RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
106            // RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>::default()),
107            RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>::default()),
108            RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
109            RecursionAir::PublicValues(PublicValuesChip::default()),
110        ]
111    }
112
113    pub fn get_all_wide() -> Vec<Self> {
114        vec![
115            // RecursionAir::Program(ProgramChip::default()),
116            RecursionAir::MemoryConst(MemoryConstChip::default()),
117            RecursionAir::MemoryVar(MemoryVarChip::default()),
118            RecursionAir::BaseAlu(BaseAluChip::default()),
119            RecursionAir::ExtAlu(ExtAluChip::default()),
120            // RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
121            RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>::default()),
122            RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>::default()),
123            RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
124            RecursionAir::PublicValues(PublicValuesChip::default()),
125        ]
126    }
127
128    pub fn get_all_with_padding(
129        fri_fold_padding: usize,
130        poseidon2_padding: usize,
131        erbl_padding: usize,
132    ) -> Vec<Self> {
133        vec![
134            // RecursionAir::Program(ProgramChip::default()),
135            RecursionAir::MemoryConst(MemoryConstChip::default()),
136            RecursionAir::MemoryVar(MemoryVarChip::default()),
137            RecursionAir::BaseAlu(BaseAluChip::default()),
138            RecursionAir::ExtAlu(ExtAluChip::default()),
139            // RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>::default()),
140            RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE> {
141                fixed_log2_rows: Some(poseidon2_padding),
142                pad: true,
143            }),
144            RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE> {
145                fixed_log2_rows: Some(erbl_padding),
146                pad: true,
147            }),
148            RecursionAir::FriFold(FriFoldChip::<DEGREE> {
149                fixed_log2_rows: Some(fri_fold_padding),
150                pad: true,
151            }),
152            RecursionAir::PublicValues(PublicValuesChip::default()),
153        ]
154    }
155
156    // pub fn get_wrap_dyn_all() -> Vec<Self> {
157    //     once(RecursionAir::Program(ProgramChip))
158    //         .chain(once(RecursionAir::Cpu(CpuChip {
159    //             fixed_log2_rows: None,
160    //             _phantom: PhantomData,
161    //         })))
162    //         .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip {
163    //             fixed_log2_rows: None,
164    //         })))
165    //         .chain(once(RecursionAir::Multi(MultiChip {
166    //             fixed_log2_rows: None,
167    //         })))
168    //         .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
169    //         .chain(once(RecursionAir::ExpReverseBitsLen(
170    //             ExpReverseBitsLenChip::<DEGREE> {
171    //                 fixed_log2_rows: None,
172    //                 pad: true,
173    //             },
174    //         )))
175    //         .collect()
176    // }
177
178    // pub fn get_wrap_all() -> Vec<Self> {
179    //     once(RecursionAir::Program(ProgramChip))
180    //         .chain(once(RecursionAir::Cpu(CpuChip {
181    //             fixed_log2_rows: Some(19),
182    //             _phantom: PhantomData,
183    //         })))
184    //         .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip {
185    //             fixed_log2_rows: Some(20),
186    //         })))
187    //         .chain(once(RecursionAir::Multi(MultiChip {
188    //             fixed_log2_rows: Some(17),
189    //         })))
190    //         .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
191    //         .chain(once(RecursionAir::ExpReverseBitsLen(
192    //             ExpReverseBitsLenChip::<DEGREE> {
193    //                 fixed_log2_rows: None,
194    //                 pad: true,
195    //             },
196    //         )))
197    //         .collect()
198    // }
199}
200
#[cfg(test)]
pub mod tests {

    use std::sync::Arc;

    use machine::RecursionAir;
    use p3_baby_bear::DiffusionMatrixBabyBear;
    use p3_field::{
        extension::{BinomialExtensionField, HasFrobenius},
        AbstractExtensionField, AbstractField, Field,
    };
    use rand::prelude::*;
    use sp1_core_machine::utils::run_test_machine;
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    // TODO expand glob import
    use crate::{runtime::instruction as instr, *};

    type SC = BabyBearPoseidon2;
    type F = <SC as StarkGenericConfig>::Val;
    type EF = <SC as StarkGenericConfig>::Challenge;
    /// AIR used for the wide-Poseidon2 run (`DEGREE = 3`, no column padding).
    type A = RecursionAir<F, 3, 0>;
    /// AIR used for the skinny-Poseidon2 run (`DEGREE = 9`, no column padding).
    type B = RecursionAir<F, 9, 0>;

    /// Runs the given program on machines that use the wide and skinny Poseidon2 chips.
    ///
    /// The program is first executed by the recursion `Runtime`. The resulting execution
    /// record is then proven and verified twice:
    /// - with `A::machine_wide` under the default config;
    /// - with `B::machine` under the compressed config.
    ///
    /// # Panics
    ///
    /// Panics if execution fails or if either machine fails verification. The verification
    /// panic message names the machine that failed.
    pub fn run_recursion_test_machines(program: RecursionProgram<F>) {
        let program = Arc::new(program);
        let mut runtime =
            Runtime::<F, EF, DiffusionMatrixBabyBear>::new(program.clone(), SC::new().perm);
        runtime.run().expect("recursion runtime failed to execute the program");

        // Run with the poseidon2 wide chip. The record is cloned because the skinny run below
        // needs it as well.
        let wide_machine = A::machine_wide(BabyBearPoseidon2::default());
        let (pk, vk) = wide_machine.setup(&program);
        let result = run_test_machine(vec![runtime.record.clone()], wide_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed on the wide Poseidon2 machine: {:?}", e);
        }

        // Run with the poseidon2 skinny chip; this is the last use of the record, so move it.
        let skinny_machine = B::machine(BabyBearPoseidon2::compressed());
        let (pk, vk) = skinny_machine.setup(&program);
        let result = run_test_machine(vec![runtime.record], skinny_machine, pk, vk);
        if let Err(e) = result {
            panic!("Verification failed on the skinny Poseidon2 machine: {:?}", e);
        }
    }

    /// Wraps `instructions` in an otherwise-default program and runs it on both test machines.
    fn test_instructions(instructions: Vec<Instruction<F>>) {
        let program = RecursionProgram { instructions, ..Default::default() };
        run_recursion_test_machines(program);
    }

    /// Computes Fibonacci terms into addresses `0..=n` with base-field additions.
    ///
    /// Reads back the last two terms: fib(9) = 34 and fib(10) = 55 for `n = 10`.
    #[test]
    pub fn fibonacci() {
        let n = 10;

        // NOTE(review): the second argument of `instr::mem` / `instr::base_alu` appears to be
        // the access multiplicity (writes must be balanced by reads) — confirm.
        let instructions = once(instr::mem(MemAccessKind::Write, 1, 0, 0))
            .chain(once(instr::mem(MemAccessKind::Write, 2, 1, 1)))
            .chain((2..=n).map(|i| instr::base_alu(BaseAluOpcode::AddF, 2, i, i - 2, i - 1)))
            .chain(once(instr::mem(MemAccessKind::Read, 1, n - 1, 34)))
            .chain(once(instr::mem(MemAccessKind::Read, 2, n, 55)))
            .collect::<Vec<_>>();

        test_instructions(instructions);
    }

    /// Dividing a nonzero value (1) by zero must make execution or proving panic.
    #[test]
    #[should_panic]
    pub fn div_nonzero_by_zero() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 1),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ];

        test_instructions(instructions);
    }

    /// `0 / 0` is expected to evaluate to 1 (the value read back from address 2).
    #[test]
    pub fn div_zero_by_zero() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 0),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ];

        test_instructions(instructions);
    }

    /// Checks the in-circuit product of all Galois conjugates of random nonzero elements.
    ///
    /// For 100 random nonzero extension-field elements `x`, all Galois conjugates of `x` are
    /// multiplied together with `MulE` instructions. The first base coordinate of the result
    /// is then compared against the natively computed product.
    #[test]
    pub fn field_norm() {
        let mut instructions = Vec::new();

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut addr = 0;
        for _ in 0..100 {
            // Resample until not every coordinate is zero, so that `x` is nonzero.
            let inner: [F; D] = std::iter::repeat_with(|| {
                core::array::from_fn(|_| rng.sample(rand::distributions::Standard))
            })
            .find(|xs| !xs.iter().all(F::is_zero))
            .unwrap();
            let x = BinomialExtensionField::<F, D>::from_base_slice(&inner);
            let gal = x.galois_group();

            let mut acc = BinomialExtensionField::one();

            // The running product starts as one at `addr`. Each step then:
            // - writes the next conjugate to `addr + 1`;
            // - stores `mem[addr] * mem[addr + 1]` at `addr + 2`, which becomes the next `addr`.
            instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr, acc));
            for conj in gal {
                instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr + 1, conj));
                instructions.push(instr::ext_alu(ExtAluOpcode::MulE, 1, addr + 2, addr, addr + 1));

                addr += 2;
                acc *= conj;
            }
            // The product of all conjugates (the norm) lies in the base field, so only its
            // first coordinate is read back.
            let base_cmp: F = acc.as_base_slice()[0];
            instructions.push(instr::mem_single(MemAccessKind::Read, 1, addr, base_cmp));
            addr += 1;
        }

        test_instructions(instructions);
    }
}