1use p3_field::{extension::BinomiallyExtendable, PrimeField32};
2use sp1_recursion_core::runtime::D;
3use sp1_stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS};
4
5use crate::chips::{
6 alu_base::BaseAluChip,
7 alu_ext::ExtAluChip,
8 dummy::DummyChip,
9 exp_reverse_bits::ExpReverseBitsLenChip,
10 fri_fold::FriFoldChip,
11 mem::{MemoryConstChip, MemoryVarChip},
12 poseidon2_skinny::Poseidon2SkinnyChip,
13 poseidon2_wide::Poseidon2WideChip,
14 public_values::PublicValuesChip,
15};
16
/// The chips (AIRs) that can make up a recursion [`StarkMachine`].
///
/// Each variant wraps one chip. `DEGREE` is forwarded to the Poseidon2 (skinny and wide),
/// FRI-fold and exp-reverse-bits-len chips. Presumably it is the constraint degree those chips
/// are built for (the tests instantiate 3 and 9); TODO confirm. `COL_PADDING` is forwarded only
/// to [`DummyChip`].
///
/// The `MachineAir` implementation is generated by `sp1_derive::MachineAir`. It is configured by
/// the helper attributes below, which name the execution-record, program and air-builder types
/// it uses.
#[derive(sp1_derive::MachineAir)]
#[sp1_core_path = "sp1_core_machine"]
#[execution_record_path = "crate::ExecutionRecord<F>"]
#[program_path = "crate::RecursionProgram<F>"]
#[builder_path = "crate::builder::SP1RecursionAirBuilder<F = F>"]
#[eval_trait_bound = "AB::Var: 'static"]
pub enum RecursionAir<
    F: PrimeField32 + BinomiallyExtendable<D>,
    const DEGREE: usize,
    const COL_PADDING: usize,
> {
    // Memory chips (`MemoryConstChip` / `MemoryVarChip` from `chips::mem`).
    MemoryConst(MemoryConstChip<F>),
    MemoryVar(MemoryVarChip<F>),
    // ALU chips for base-field and extension-field arithmetic.
    BaseAlu(BaseAluChip),
    ExtAlu(ExtAluChip),
    // Skinny Poseidon2 chip; selected by `get_all` and `get_all_with_padding`.
    Poseidon2Skinny(Poseidon2SkinnyChip<DEGREE>),
    // Wide Poseidon2 chip; selected by `get_all_wide`.
    Poseidon2Wide(Poseidon2WideChip<DEGREE>),
    FriFold(FriFoldChip<DEGREE>),
    ExpReverseBitsLen(ExpReverseBitsLenChip<DEGREE>),
    PublicValues(PublicValuesChip),
    // Placeholder chip; only used by `dummy_machine`.
    DummyWide(DummyChip<COL_PADDING>),
}
44
45impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize, const COL_PADDING: usize>
46 RecursionAir<F, DEGREE, COL_PADDING>
47{
48 pub fn machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
50 let chips = Self::get_all().into_iter().map(Chip::new).collect::<Vec<_>>();
51 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
52 }
53
54 pub fn machine_wide<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
57 let chips = Self::get_all_wide().into_iter().map(Chip::new).collect::<Vec<_>>();
58 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
59 }
60
61 pub fn machine_with_padding<SC: StarkGenericConfig<Val = F>>(
62 config: SC,
63 fri_fold_padding: usize,
64 poseidon2_padding: usize,
65 erbl_padding: usize,
66 ) -> StarkMachine<SC, Self> {
67 let chips = Self::get_all_with_padding(fri_fold_padding, poseidon2_padding, erbl_padding)
68 .into_iter()
69 .map(Chip::new)
70 .collect::<Vec<_>>();
71 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
72 }
73
74 pub fn dummy_machine<SC: StarkGenericConfig<Val = F>>(
75 config: SC,
76 log_height: usize,
77 ) -> StarkMachine<SC, Self> {
78 let chips = vec![RecursionAir::DummyWide(DummyChip::new(log_height))];
79 StarkMachine::new(config, chips.into_iter().map(Chip::new).collect(), PROOF_MAX_NUM_PVS)
80 }
81 pub fn get_all() -> Vec<Self> {
100 vec![
101 RecursionAir::MemoryConst(MemoryConstChip::default()),
102 RecursionAir::MemoryVar(MemoryVarChip::default()),
103 RecursionAir::BaseAlu(BaseAluChip::default()),
104 RecursionAir::ExtAlu(ExtAluChip::default()),
105 RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
106 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>::default()),
108 RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
109 RecursionAir::PublicValues(PublicValuesChip::default()),
110 ]
111 }
112
113 pub fn get_all_wide() -> Vec<Self> {
114 vec![
115 RecursionAir::MemoryConst(MemoryConstChip::default()),
117 RecursionAir::MemoryVar(MemoryVarChip::default()),
118 RecursionAir::BaseAlu(BaseAluChip::default()),
119 RecursionAir::ExtAlu(ExtAluChip::default()),
120 RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>::default()),
122 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>::default()),
123 RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
124 RecursionAir::PublicValues(PublicValuesChip::default()),
125 ]
126 }
127
128 pub fn get_all_with_padding(
129 fri_fold_padding: usize,
130 poseidon2_padding: usize,
131 erbl_padding: usize,
132 ) -> Vec<Self> {
133 vec![
134 RecursionAir::MemoryConst(MemoryConstChip::default()),
136 RecursionAir::MemoryVar(MemoryVarChip::default()),
137 RecursionAir::BaseAlu(BaseAluChip::default()),
138 RecursionAir::ExtAlu(ExtAluChip::default()),
139 RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE> {
141 fixed_log2_rows: Some(poseidon2_padding),
142 pad: true,
143 }),
144 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE> {
145 fixed_log2_rows: Some(erbl_padding),
146 pad: true,
147 }),
148 RecursionAir::FriFold(FriFoldChip::<DEGREE> {
149 fixed_log2_rows: Some(fri_fold_padding),
150 pad: true,
151 }),
152 RecursionAir::PublicValues(PublicValuesChip::default()),
153 ]
154 }
155
156 }
200
#[cfg(test)]
pub mod tests {
    //! Tests that execute small recursion programs.
    //!
    //! Each program's execution record is then checked with `run_test_machine` against two
    //! machine configurations.

    use std::sync::Arc;

    use machine::RecursionAir;
    use p3_baby_bear::DiffusionMatrixBabyBear;
    use p3_field::{
        extension::{BinomialExtensionField, HasFrobenius},
        AbstractExtensionField, AbstractField, Field,
    };
    use rand::prelude::*;
    use sp1_core_machine::utils::run_test_machine;
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use crate::{runtime::instruction as instr, *};

    type SC = BabyBearPoseidon2;
    type F = <SC as StarkGenericConfig>::Val;
    type EF = <SC as StarkGenericConfig>::Challenge;
    type A = RecursionAir<F, 3, 0>;
    type B = RecursionAir<F, 9, 0>;

    /// Runs `program` in the recursion runtime, then checks the record with two machines.
    ///
    /// - `A`: `machine_wide`, degree 3, default config.
    /// - `B`: `machine` (skinny Poseidon2), degree 9, compressed config.
    ///
    /// Panics if execution fails or either `run_test_machine` call returns an error.
    pub fn run_recursion_test_machines(program: RecursionProgram<F>) {
        let program = Arc::new(program);
        let mut runtime =
            Runtime::<F, EF, DiffusionMatrixBabyBear>::new(Arc::clone(&program), SC::new().perm);
        runtime.run().unwrap();

        let wide = A::machine_wide(BabyBearPoseidon2::default());
        let (wide_pk, wide_vk) = wide.setup(&program);
        if let Err(e) = run_test_machine(vec![runtime.record.clone()], wide, wide_pk, wide_vk) {
            panic!("Verification failed: {:?}", e);
        }

        let skinny = B::machine(BabyBearPoseidon2::compressed());
        let (skinny_pk, skinny_vk) = skinny.setup(&program);
        if let Err(e) = run_test_machine(vec![runtime.record], skinny, skinny_pk, skinny_vk) {
            panic!("Verification failed: {:?}", e);
        }
    }

    /// Wraps `instructions` in an otherwise-default program and runs it through
    /// [`run_recursion_test_machines`].
    fn test_instructions(instructions: Vec<Instruction<F>>) {
        run_recursion_test_machines(RecursionProgram { instructions, ..Default::default() });
    }

    #[test]
    pub fn fibonacci() {
        // Address i receives fib(i) (fib(0) = 0, fib(1) = 1). Each cell is consumed twice by
        // later additions, hence multiplicity 2, except address 0, which is only read once.
        let n = 10;

        let mut instructions: Vec<Instruction<F>> = vec![
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 2, 1, 1),
        ];
        instructions.extend(
            (2..=n).map(|i| instr::base_alu(BaseAluOpcode::AddF, 2, i, i - 2, i - 1)),
        );
        // fib(9) = 34 and fib(10) = 55.
        instructions.push(instr::mem(MemAccessKind::Read, 1, n - 1, 34));
        instructions.push(instr::mem(MemAccessKind::Read, 2, n, 55));

        test_instructions(instructions);
    }

    #[test]
    #[should_panic]
    pub fn div_nonzero_by_zero() {
        // addr2 = addr1 / addr0 = 1 / 0, which must not succeed.
        test_instructions(vec![
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 1),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ]);
    }

    #[test]
    pub fn div_zero_by_zero() {
        // addr2 = addr1 / addr0 = 0 / 0, which this test expects to read back as 1.
        test_instructions(vec![
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 0),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ]);
    }

    #[test]
    pub fn field_norm() {
        // For 100 random nonzero extension elements, multiply all of their Galois conjugates
        // together with `MulE` instructions (i.e. compute the field norm). Then check the
        // product's first base coordinate against the natively computed value.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut instructions = Vec::new();
        let mut addr = 0;

        for _ in 0..100 {
            // Resample until at least one coordinate is nonzero.
            let coeffs: [F; 4] = std::iter::repeat_with(|| {
                core::array::from_fn(|_| rng.sample(rand::distributions::Standard))
            })
            .find(|cs| cs.iter().any(|c| !c.is_zero()))
            .unwrap();
            let elem = BinomialExtensionField::<F, D>::from_base_slice(&coeffs);

            // `addr` always holds the running product; each step writes the next conjugate at
            // `addr + 1` and stores the new product at `addr + 2`.
            let mut product = BinomialExtensionField::one();
            instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr, product));
            for conjugate in elem.galois_group() {
                instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr + 1, conjugate));
                instructions.push(instr::ext_alu(
                    ExtAluOpcode::MulE,
                    1,
                    addr + 2,
                    addr,
                    addr + 1,
                ));
                addr += 2;
                product *= conjugate;
            }

            let expected: F = product.as_base_slice()[0];
            instructions.push(instr::mem_single(MemAccessKind::Read, 1, addr, expected));
            addr += 1;
        }

        test_instructions(instructions);
    }
}