1use std::ops::{Add, AddAssign};
2
3use hashbrown::HashMap;
4use p3_field::{extension::BinomiallyExtendable, PrimeField32};
5use sp1_stark::{
6 air::{InteractionScope, MachineAir},
7 shape::OrderedShape,
8 Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS,
9};
10
11use crate::{
12 chips::{
13 alu_base::{BaseAluChip, NUM_BASE_ALU_ENTRIES_PER_ROW},
14 alu_ext::{ExtAluChip, NUM_EXT_ALU_ENTRIES_PER_ROW},
15 batch_fri::BatchFRIChip,
16 exp_reverse_bits::ExpReverseBitsLenChip,
17 fri_fold::FriFoldChip,
18 mem::{
19 constant::NUM_CONST_MEM_ENTRIES_PER_ROW, variable::NUM_VAR_MEM_ENTRIES_PER_ROW,
20 MemoryConstChip, MemoryVarChip,
21 },
22 poseidon2_skinny::Poseidon2SkinnyChip,
23 poseidon2_wide::Poseidon2WideChip,
24 public_values::{PublicValuesChip, PUB_VALUES_LOG_HEIGHT},
25 select::SelectChip,
26 },
27 instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr},
28 shape::RecursionShape,
29 ExpReverseBitsInstr, Instruction, RecursionProgram, D,
30};
31
/// The set of AIRs (chips) a recursion machine can be assembled from.
///
/// Each variant wraps exactly one chip type. `F` is the base field the memory chips are
/// parameterized by, and `DEGREE` is forwarded unchanged to the chips that take a const
/// generic parameter.
#[derive(sp1_derive::MachineAir)]
#[sp1_core_path = "sp1_core_machine"]
#[execution_record_path = "crate::ExecutionRecord<F>"]
#[program_path = "crate::RecursionProgram<F>"]
#[builder_path = "crate::builder::SP1RecursionAirBuilder<F = F>"]
#[eval_trait_bound = "AB::Var: 'static"]
pub enum RecursionAir<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> {
    /// Wraps [`MemoryConstChip`].
    MemoryConst(MemoryConstChip<F>),
    /// Wraps [`MemoryVarChip`].
    MemoryVar(MemoryVarChip<F>),
    /// Wraps [`BaseAluChip`].
    BaseAlu(BaseAluChip),
    /// Wraps [`ExtAluChip`].
    ExtAlu(ExtAluChip),
    /// Wraps [`Poseidon2SkinnyChip`].
    Poseidon2Skinny(Poseidon2SkinnyChip<DEGREE>),
    /// Wraps [`Poseidon2WideChip`].
    Poseidon2Wide(Poseidon2WideChip<DEGREE>),
    /// Wraps [`SelectChip`].
    Select(SelectChip),
    /// Wraps [`FriFoldChip`].
    FriFold(FriFoldChip<DEGREE>),
    /// Wraps [`BatchFRIChip`].
    BatchFRI(BatchFRIChip<DEGREE>),
    /// Wraps [`ExpReverseBitsLenChip`].
    ExpReverseBitsLen(ExpReverseBitsLenChip<DEGREE>),
    /// Wraps [`PublicValuesChip`].
    PublicValues(PublicValuesChip),
}
51
/// Per-chip event totals accumulated over the instructions of a recursion program.
///
/// All counters start at zero (`Default`). The type is `Copy`, so it can be threaded through a
/// fold by value, and it is comparable (`PartialEq`/`Eq`) so totals can be checked directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct RecursionAirEventCount {
    /// Number of constant-memory events.
    pub mem_const_events: usize,
    /// Number of variable-memory events.
    pub mem_var_events: usize,
    /// Number of base-field ALU events.
    pub base_alu_events: usize,
    /// Number of extension-field ALU events.
    pub ext_alu_events: usize,
    /// Number of Poseidon2 events.
    pub poseidon2_wide_events: usize,
    /// Number of FRI-fold events.
    pub fri_fold_events: usize,
    /// Number of batch-FRI events.
    pub batch_fri_events: usize,
    /// Number of select events.
    pub select_events: usize,
    /// Number of exp-reverse-bits-len events.
    pub exp_reverse_bits_len_events: usize,
}
64
65impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAir<F, DEGREE> {
66 pub fn machine_wide_with_all_chips<SC: StarkGenericConfig<Val = F>>(
68 config: SC,
69 ) -> StarkMachine<SC, Self> {
70 let chips = [
71 RecursionAir::MemoryConst(MemoryConstChip::default()),
72 RecursionAir::MemoryVar(MemoryVarChip::default()),
73 RecursionAir::BaseAlu(BaseAluChip),
74 RecursionAir::ExtAlu(ExtAluChip),
75 RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>),
76 RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
77 RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
78 RecursionAir::Select(SelectChip),
79 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
80 RecursionAir::PublicValues(PublicValuesChip),
81 ]
82 .map(Chip::new)
83 .into_iter()
84 .collect::<Vec<_>>();
85 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
86 }
87
88 pub fn machine_skinny_with_all_chips<SC: StarkGenericConfig<Val = F>>(
90 config: SC,
91 ) -> StarkMachine<SC, Self> {
92 let chips = [
93 RecursionAir::MemoryConst(MemoryConstChip::default()),
94 RecursionAir::MemoryVar(MemoryVarChip::default()),
95 RecursionAir::BaseAlu(BaseAluChip),
96 RecursionAir::ExtAlu(ExtAluChip),
97 RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
98 RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
99 RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
100 RecursionAir::Select(SelectChip),
101 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
102 RecursionAir::PublicValues(PublicValuesChip),
103 ]
104 .map(Chip::new)
105 .into_iter()
106 .collect::<Vec<_>>();
107 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
108 }
109
110 pub fn compress_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
112 let chips = [
113 RecursionAir::MemoryConst(MemoryConstChip::default()),
114 RecursionAir::MemoryVar(MemoryVarChip::default()),
115 RecursionAir::BaseAlu(BaseAluChip),
116 RecursionAir::ExtAlu(ExtAluChip),
117 RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>),
118 RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
119 RecursionAir::Select(SelectChip),
120 RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
121 RecursionAir::PublicValues(PublicValuesChip),
122 ]
123 .map(Chip::new)
124 .into_iter()
125 .collect::<Vec<_>>();
126 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
127 }
128
129 pub fn shrink_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
130 Self::compress_machine(config)
131 }
132
133 pub fn wrap_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
138 let chips = [
139 RecursionAir::MemoryConst(MemoryConstChip::default()),
140 RecursionAir::MemoryVar(MemoryVarChip::default()),
141 RecursionAir::BaseAlu(BaseAluChip),
142 RecursionAir::ExtAlu(ExtAluChip),
143 RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
144 RecursionAir::Select(SelectChip),
146 RecursionAir::PublicValues(PublicValuesChip),
147 ]
148 .map(Chip::new)
149 .into_iter()
150 .collect::<Vec<_>>();
151 StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
152 }
153
154 pub fn shrink_shape() -> RecursionShape {
155 let shape = HashMap::from(
156 [
157 (Self::MemoryVar(MemoryVarChip::default()), 18),
158 (Self::Select(SelectChip), 18),
159 (Self::MemoryConst(MemoryConstChip::default()), 17),
160 (Self::BatchFRI(BatchFRIChip::<DEGREE>), 17),
161 (Self::BaseAlu(BaseAluChip), 17),
162 (Self::ExtAlu(ExtAluChip), 18),
163 (Self::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>), 17),
164 (Self::Poseidon2Wide(Poseidon2WideChip::<DEGREE>), 16),
165 (Self::PublicValues(PublicValuesChip), PUB_VALUES_LOG_HEIGHT),
166 ]
167 .map(|(chip, log_height)| (chip.name(), log_height)),
168 );
169 RecursionShape { inner: shape }
170 }
171
172 pub fn heights(program: &RecursionProgram<F>) -> Vec<(String, usize)> {
173 let heights = program
174 .inner
175 .iter()
176 .fold(RecursionAirEventCount::default(), |heights, instruction| heights + instruction);
177
178 [
179 (
180 Self::MemoryConst(MemoryConstChip::default()),
181 heights.mem_const_events.div_ceil(NUM_CONST_MEM_ENTRIES_PER_ROW),
182 ),
183 (
184 Self::MemoryVar(MemoryVarChip::default()),
185 heights.mem_var_events.div_ceil(NUM_VAR_MEM_ENTRIES_PER_ROW),
186 ),
187 (
188 Self::BaseAlu(BaseAluChip),
189 heights.base_alu_events.div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW),
190 ),
191 (
192 Self::ExtAlu(ExtAluChip),
193 heights.ext_alu_events.div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW),
194 ),
195 (Self::Poseidon2Wide(Poseidon2WideChip::<DEGREE>), heights.poseidon2_wide_events),
196 (Self::BatchFRI(BatchFRIChip::<DEGREE>), heights.batch_fri_events),
197 (Self::Select(SelectChip), heights.select_events),
198 (
199 Self::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
200 heights.exp_reverse_bits_len_events,
201 ),
202 (Self::PublicValues(PublicValuesChip), PUB_VALUES_LOG_HEIGHT),
203 ]
204 .map(|(chip, log_height)| (chip.name(), log_height))
205 .to_vec()
206 }
207}
208
209impl<F> AddAssign<&Instruction<F>> for RecursionAirEventCount {
210 #[inline]
211 fn add_assign(&mut self, rhs: &Instruction<F>) {
212 match rhs {
213 Instruction::BaseAlu(_) => self.base_alu_events += 1,
214 Instruction::ExtAlu(_) => self.ext_alu_events += 1,
215 Instruction::Mem(_) => self.mem_const_events += 1,
216 Instruction::Poseidon2(_) => self.poseidon2_wide_events += 1,
217 Instruction::Select(_) => self.select_events += 1,
218 Instruction::ExpReverseBitsLen(ExpReverseBitsInstr { addrs, .. }) => {
219 self.exp_reverse_bits_len_events += addrs.exp.len()
220 }
221 Instruction::Hint(HintInstr { output_addrs_mults }) |
222 Instruction::HintBits(HintBitsInstr {
223 output_addrs_mults,
224 input_addr: _, }) => self.mem_var_events += output_addrs_mults.len(),
226 Instruction::HintExt2Felts(HintExt2FeltsInstr {
227 output_addrs_mults,
228 input_addr: _, }) => self.mem_var_events += output_addrs_mults.len(),
230 Instruction::FriFold(_) => self.fri_fold_events += 1,
231 Instruction::BatchFRI(instr) => {
232 self.batch_fri_events += instr.base_vec_addrs.p_at_x.len()
233 }
234 Instruction::HintAddCurve(instr) => {
235 self.mem_var_events += instr.output_x_addrs_mults.len();
236 self.mem_var_events += instr.output_y_addrs_mults.len();
237 }
238 Instruction::CommitPublicValues(_) => {}
239 Instruction::Print(_) => {}
240 #[cfg(feature = "debug")]
241 Instruction::DebugBacktrace(_) => {}
242 }
243 }
244}
245
246impl<F> Add<&Instruction<F>> for RecursionAirEventCount {
247 type Output = Self;
248
249 #[inline]
250 fn add(mut self, rhs: &Instruction<F>) -> Self::Output {
251 self += rhs;
252 self
253 }
254}
255
256impl From<RecursionShape> for OrderedShape {
257 fn from(value: RecursionShape) -> Self {
258 value.inner.into_iter().collect()
259 }
260}
261
#[cfg(test)]
pub mod tests {

    use std::sync::Arc;

    use machine::RecursionAir;
    use p3_baby_bear::DiffusionMatrixBabyBear;
    use p3_field::{
        extension::{BinomialExtensionField, HasFrobenius},
        AbstractExtensionField, AbstractField, Field,
    };
    use rand::prelude::*;
    use sp1_core_machine::utils::run_test_machine;
    use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use crate::{runtime::instruction as instr, *};

    type SC = BabyBearPoseidon2;
    type F = <SC as StarkGenericConfig>::Val;
    type EF = <SC as StarkGenericConfig>::Challenge;
    type A = RecursionAir<F, 3>;
    type B = RecursionAir<F, 9>;

    /// Executes `program` once, then proves and verifies the resulting record on both the wide
    /// machine (`DEGREE = 3`) and the skinny machine (`DEGREE = 9`).
    ///
    /// Panics if execution fails or if either machine fails verification.
    pub fn run_recursion_test_machines(program: RecursionProgram<F>) {
        let program = Arc::new(program);

        // Run the program a single time; both machines consume the same execution record.
        let mut runtime =
            Runtime::<F, EF, DiffusionMatrixBabyBear>::new(Arc::clone(&program), SC::new().perm);
        runtime.run().unwrap();
        let record = runtime.record;

        let wide_machine = A::machine_wide_with_all_chips(BabyBearPoseidon2::default());
        let (wide_pk, wide_vk) = wide_machine.setup(&program);
        run_test_machine(vec![record.clone()], wide_machine, wide_pk, wide_vk)
            .expect("Verification failed");

        let skinny_machine =
            B::machine_skinny_with_all_chips(BabyBearPoseidon2::ultra_compressed());
        let (skinny_pk, skinny_vk) = skinny_machine.setup(&program);
        run_test_machine(vec![record], skinny_machine, skinny_pk, skinny_vk)
            .expect("Verification failed");
    }

    /// Turns `instrs` into a linear program and runs it through
    /// [`run_recursion_test_machines`].
    pub fn test_recursion_linear_program(instrs: Vec<Instruction<F>>) {
        let program = linear_program(instrs).unwrap();
        run_recursion_test_machines(program);
    }

    /// Computes Fibonacci numbers with `AddF` and reads back the last two (34 and 55).
    #[test]
    pub fn fibonacci() {
        let n = 10;

        // Seed values at addresses 0 and 1.
        let seeds = [
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 2, 1, 1),
        ];
        // Address `i` receives the sum of addresses `i - 2` and `i - 1`.
        let sums = (2..=n).map(|i| instr::base_alu(BaseAluOpcode::AddF, 2, i, i - 2, i - 1));
        // Read back the final two values.
        let checks = [
            instr::mem(MemAccessKind::Read, 1, n - 1, 34),
            instr::mem(MemAccessKind::Read, 2, n, 55),
        ];

        let instructions: Vec<Instruction<F>> =
            seeds.into_iter().chain(sums).chain(checks).collect();

        test_recursion_linear_program(instructions);
    }

    /// Dividing a nonzero value by zero must make the test machines panic.
    #[test]
    #[should_panic]
    pub fn div_nonzero_by_zero() {
        let instructions = Vec::from([
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 1),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ]);

        test_recursion_linear_program(instructions);
    }

    /// Dividing zero by zero is accepted; the program reads back `1` as the quotient.
    #[test]
    pub fn div_zero_by_zero() {
        let instructions = Vec::from([
            instr::mem(MemAccessKind::Write, 1, 0, 0),
            instr::mem(MemAccessKind::Write, 1, 1, 0),
            instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0),
            instr::mem(MemAccessKind::Read, 1, 2, 1),
        ]);

        test_recursion_linear_program(instructions);
    }

    /// Multiplies the Galois conjugates of 100 random nonzero extension elements with `MulE`
    /// and checks the first base coefficient of each product.
    #[test]
    pub fn field_norm() {
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut instructions = Vec::new();
        let mut addr = 0;

        for _ in 0..100 {
            // Resample until at least one coefficient is nonzero.
            let coeffs: [F; 4] = std::iter::repeat_with(|| {
                core::array::from_fn(|_| rng.sample(rand::distributions::Standard))
            })
            .find(|xs| !xs.iter().all(F::is_zero))
            .unwrap();
            let x = BinomialExtensionField::<F, D>::from_base_slice(&coeffs);
            let conjugates = x.galois_group();

            // Running product, mirrored in memory at `addr`.
            let mut acc = BinomialExtensionField::one();
            instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr, acc));

            for conj in conjugates {
                instructions.extend([
                    instr::mem_ext(MemAccessKind::Write, 1, addr + 1, conj),
                    instr::ext_alu(ExtAluOpcode::MulE, 1, addr + 2, addr, addr + 1),
                ]);

                addr += 2;
                acc *= conj;
            }

            let expected: F = acc.as_base_slice()[0];
            instructions.push(instr::mem_single(MemAccessKind::Read, 1, addr, expected));
            addr += 1;
        }

        test_recursion_linear_program(instructions);
    }
}