1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
pub mod config;
pub mod poseidon2;
pub mod utils;

use crate::{
    cpu::CpuChip, exp_reverse_bits::ExpReverseBitsLenChip, fri_fold::FriFoldChip,
    memory::MemoryGlobalChip, multi::MultiChip, poseidon2_wide::Poseidon2WideChip,
    program::ProgramChip, range_check::RangeCheckChip,
};
use core::iter::once;
use p3_field::{extension::BinomiallyExtendable, PrimeField32};
use sp1_stark::{Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS};
use std::marker::PhantomData;

use crate::runtime::D;

/// [`RecursionAir`] instantiated with `DEGREE = 3`.
///
/// NOTE(review): `DEGREE` is forwarded to the degree-parameterized chips (CPU, Poseidon2Wide,
/// FriFold, Multi, ExpReverseBitsLen); presumably it bounds their constraint degree — confirm.
pub type RecursionAirWideDeg3<F> = RecursionAir<F, 3>;
/// [`RecursionAir`] instantiated with `DEGREE = 9`.
pub type RecursionAirWideDeg9<F> = RecursionAir<F, 9>;
/// [`RecursionAir`] instantiated with `DEGREE = 17`.
pub type RecursionAirWideDeg17<F> = RecursionAir<F, 17>;

/// The chips (AIRs) that can make up a recursion machine; each variant wraps exactly one chip.
///
/// `MachineAir` is derived for the whole enum. NOTE(review): presumably the derived impl
/// dispatches each trait method to the wrapped chip — confirm against `sp1_derive`.
/// The `#[..._path]` / `#[eval_trait_bound]` attributes configure that derive (crate path,
/// execution record type, program type, constraint builder, and an extra `eval` bound).
///
/// - `F`: the field the machine operates over (`SC::Val` in the constructors on this type).
/// - `DEGREE`: forwarded to every chip that takes a degree parameter.
#[derive(sp1_derive::MachineAir)]
#[sp1_core_path = "sp1_stark"]
#[execution_record_path = "crate::runtime::ExecutionRecord<F>"]
#[program_path = "crate::runtime::RecursionProgram<F>"]
#[builder_path = "crate::air::SP1RecursionAirBuilder<F = F>"]
#[eval_trait_bound = "AB::Var: 'static"]
pub enum RecursionAir<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> {
    /// Program chip; used by `get_all`, `get_wrap_all` and `get_wrap_dyn_all`.
    Program(ProgramChip),
    /// CPU chip; used by all three chip lists.
    Cpu(CpuChip<F, DEGREE>),
    /// Global memory chip; used by all three chip lists.
    MemoryGlobal(MemoryGlobalChip),
    /// Poseidon2 chip; used only by `get_all`.
    Poseidon2Wide(Poseidon2WideChip<DEGREE>),
    /// FRI fold chip; used only by `get_all`.
    FriFold(FriFoldChip<DEGREE>),
    /// Range-check chip; used by all three chip lists.
    RangeCheck(RangeCheckChip<F>),
    /// Used only by the wrap chip lists (`get_wrap_all`, `get_wrap_dyn_all`).
    /// NOTE(review): presumably stands in for `FriFold` + `Poseidon2Wide` there — confirm.
    Multi(MultiChip<DEGREE>),
    /// Exp-reverse-bits-len chip; used by `get_all` and `get_wrap_dyn_all`, not `get_wrap_all`.
    ExpReverseBitsLen(ExpReverseBitsLenChip<DEGREE>),
}

impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAir<F, DEGREE> {
    /// A recursion machine that can have dynamic trace sizes.
    pub fn machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
        let chips = Self::get_all().into_iter().map(Chip::new).collect::<Vec<_>>();
        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
    }

    /// A recursion machine with fixed trace sizes tuned to work specifically for the wrap layer.
    pub fn wrap_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
        let chips = Self::get_wrap_all().into_iter().map(Chip::new).collect::<Vec<_>>();
        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
    }

    /// A recursion machine with fixed trace sizes tuned to work specifically for the wrap layer.
    pub fn wrap_machine_dyn<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
        let chips = Self::get_wrap_dyn_all().into_iter().map(Chip::new).collect::<Vec<_>>();
        StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS)
    }

    pub fn get_all() -> Vec<Self> {
        once(RecursionAir::Program(ProgramChip))
            .chain(once(RecursionAir::Cpu(CpuChip {
                fixed_log2_rows: None,
                _phantom: PhantomData,
            })))
            .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: None })))
            .chain(once(RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE> {
                fixed_log2_rows: None,
                pad: true,
            })))
            .chain(once(RecursionAir::FriFold(FriFoldChip::<DEGREE> {
                fixed_log2_rows: None,
                pad: true,
            })))
            .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
            .chain(once(RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE> {
                fixed_log2_rows: None,
                pad: true,
            })))
            .collect()
    }

    pub fn get_wrap_dyn_all() -> Vec<Self> {
        once(RecursionAir::Program(ProgramChip))
            .chain(once(RecursionAir::Cpu(CpuChip {
                fixed_log2_rows: None,
                _phantom: PhantomData,
            })))
            .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: None })))
            .chain(once(RecursionAir::Multi(MultiChip { fixed_log2_rows: None })))
            .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
            .chain(once(RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE> {
                fixed_log2_rows: None,
                pad: true,
            })))
            .collect()
    }

    pub fn get_wrap_all() -> Vec<Self> {
        once(RecursionAir::Program(ProgramChip))
            .chain(once(RecursionAir::Cpu(CpuChip {
                fixed_log2_rows: Some(20),
                _phantom: PhantomData,
            })))
            .chain(once(RecursionAir::MemoryGlobal(MemoryGlobalChip { fixed_log2_rows: Some(19) })))
            .chain(once(RecursionAir::Multi(MultiChip { fixed_log2_rows: Some(17) })))
            .chain(once(RecursionAir::RangeCheck(RangeCheckChip::default())))
            .collect()
    }
}