use std::ops::{Add, AddAssign};
use hashbrown::HashMap;
use p3_field::{extension::BinomiallyExtendable, PrimeField32};
use sp1_stark::{
air::{InteractionScope, MachineAir},
shape::OrderedShape,
Chip, StarkGenericConfig, StarkMachine, PROOF_MAX_NUM_PVS,
};
use crate::{
chips::{
alu_base::{BaseAluChip, NUM_BASE_ALU_ENTRIES_PER_ROW},
alu_ext::{ExtAluChip, NUM_EXT_ALU_ENTRIES_PER_ROW},
batch_fri::BatchFRIChip,
exp_reverse_bits::ExpReverseBitsLenChip,
fri_fold::FriFoldChip,
mem::{
constant::NUM_CONST_MEM_ENTRIES_PER_ROW, variable::NUM_VAR_MEM_ENTRIES_PER_ROW,
MemoryConstChip, MemoryVarChip,
},
poseidon2_skinny::Poseidon2SkinnyChip,
poseidon2_wide::Poseidon2WideChip,
public_values::{PublicValuesChip, PUB_VALUES_LOG_HEIGHT},
select::SelectChip,
},
instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr},
shape::RecursionShape,
ExpReverseBitsInstr, Instruction, RecursionProgram, D,
};
/// The chips (AIRs) from which the recursion machines in this file are assembled.
///
/// Each variant wraps exactly one chip type. The `DEGREE` const parameter is forwarded to the
/// chips that are generic over it (both Poseidon2 chips, `FriFold`, `BatchFRI` and
/// `ExpReverseBitsLen`); `F` is forwarded to the two memory chips.
///
/// The `MachineAir` implementation is generated by the `sp1_derive::MachineAir` derive, which is
/// configured through the helper attributes below.
#[derive(sp1_derive::MachineAir)]
#[sp1_core_path = "sp1_core_machine"]
#[execution_record_path = "crate::ExecutionRecord<F>"]
#[program_path = "crate::RecursionProgram<F>"]
#[builder_path = "crate::builder::SP1RecursionAirBuilder<F = F>"]
#[eval_trait_bound = "AB::Var: 'static"]
pub enum RecursionAir<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> {
/// Constant-memory chip; `heights` sizes it from the number of `Instruction::Mem` entries.
MemoryConst(MemoryConstChip<F>),
/// Variable-memory chip; `heights` sizes it from the outputs of the hint instructions.
MemoryVar(MemoryVarChip<F>),
/// Chip for base-field ALU instructions.
BaseAlu(BaseAluChip),
/// Chip for extension-field ALU instructions.
ExtAlu(ExtAluChip),
/// Skinny Poseidon2 chip; used by `machine_skinny_with_all_chips` and `wrap_machine`.
Poseidon2Skinny(Poseidon2SkinnyChip<DEGREE>),
/// Wide Poseidon2 chip; used by `machine_wide_with_all_chips` and `compress_machine`.
Poseidon2Wide(Poseidon2WideChip<DEGREE>),
/// Chip for select instructions.
Select(SelectChip),
/// FRI-fold chip; only part of the two `*_with_all_chips` machines.
FriFold(FriFoldChip<DEGREE>),
/// Batch-FRI chip.
BatchFRI(BatchFRIChip<DEGREE>),
/// Chip for `ExpReverseBitsLen` instructions.
ExpReverseBitsLen(ExpReverseBitsLenChip<DEGREE>),
/// Public-values chip; part of every machine built in this file.
PublicValues(PublicValuesChip),
}
/// Per-chip event tallies for a recursion program.
///
/// A tally is built by adding instructions one at a time through the
/// `AddAssign<&Instruction<F>>` / `Add<&Instruction<F>>` impls in this file.
/// `RecursionAir::heights` folds a whole program into one tally and turns it into row counts.
#[derive(Debug, Clone, Copy, Default)]
pub struct RecursionAirEventCount {
/// Incremented once per `Instruction::Mem`.
pub mem_const_events: usize,
/// Incremented by the number of output addresses of every hint-style instruction
/// (`Hint`, `HintBits`, `HintExt2Felts`, `HintAddCurve`).
pub mem_var_events: usize,
/// Incremented once per `Instruction::BaseAlu`.
pub base_alu_events: usize,
/// Incremented once per `Instruction::ExtAlu`.
pub ext_alu_events: usize,
/// Incremented once per `Instruction::Poseidon2`.
pub poseidon2_wide_events: usize,
/// Incremented once per `Instruction::FriFold`. Not consumed by `RecursionAir::heights`.
pub fri_fold_events: usize,
/// Incremented by `base_vec_addrs.p_at_x.len()` for every `Instruction::BatchFRI`.
pub batch_fri_events: usize,
/// Incremented once per `Instruction::Select`.
pub select_events: usize,
/// Incremented by `addrs.exp.len()` for every `Instruction::ExpReverseBitsLen`.
pub exp_reverse_bits_len_events: usize,
}
impl<F: PrimeField32 + BinomiallyExtendable<D>, const DEGREE: usize> RecursionAir<F, DEGREE> {
pub fn machine_wide_with_all_chips<SC: StarkGenericConfig<Val = F>>(
config: SC,
) -> StarkMachine<SC, Self> {
let chips = [
RecursionAir::MemoryConst(MemoryConstChip::default()),
RecursionAir::MemoryVar(MemoryVarChip::default()),
RecursionAir::BaseAlu(BaseAluChip),
RecursionAir::ExtAlu(ExtAluChip),
RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>),
RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
RecursionAir::Select(SelectChip),
RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
RecursionAir::PublicValues(PublicValuesChip),
]
.map(Chip::new)
.into_iter()
.collect::<Vec<_>>();
StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
}
pub fn machine_skinny_with_all_chips<SC: StarkGenericConfig<Val = F>>(
config: SC,
) -> StarkMachine<SC, Self> {
let chips = [
RecursionAir::MemoryConst(MemoryConstChip::default()),
RecursionAir::MemoryVar(MemoryVarChip::default()),
RecursionAir::BaseAlu(BaseAluChip),
RecursionAir::ExtAlu(ExtAluChip),
RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
RecursionAir::FriFold(FriFoldChip::<DEGREE>::default()),
RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
RecursionAir::Select(SelectChip),
RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
RecursionAir::PublicValues(PublicValuesChip),
]
.map(Chip::new)
.into_iter()
.collect::<Vec<_>>();
StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
}
pub fn compress_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
let chips = [
RecursionAir::MemoryConst(MemoryConstChip::default()),
RecursionAir::MemoryVar(MemoryVarChip::default()),
RecursionAir::BaseAlu(BaseAluChip),
RecursionAir::ExtAlu(ExtAluChip),
RecursionAir::Poseidon2Wide(Poseidon2WideChip::<DEGREE>),
RecursionAir::BatchFRI(BatchFRIChip::<DEGREE>),
RecursionAir::Select(SelectChip),
RecursionAir::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
RecursionAir::PublicValues(PublicValuesChip),
]
.map(Chip::new)
.into_iter()
.collect::<Vec<_>>();
StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
}
/// Builds the shrink machine.
///
/// This is a plain delegate: the shrink machine uses exactly the chip set of
/// `compress_machine`, built with the given `config`.
pub fn shrink_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
    Self::compress_machine::<SC>(config)
}
pub fn wrap_machine<SC: StarkGenericConfig<Val = F>>(config: SC) -> StarkMachine<SC, Self> {
let chips = [
RecursionAir::MemoryConst(MemoryConstChip::default()),
RecursionAir::MemoryVar(MemoryVarChip::default()),
RecursionAir::BaseAlu(BaseAluChip),
RecursionAir::ExtAlu(ExtAluChip),
RecursionAir::Poseidon2Skinny(Poseidon2SkinnyChip::<DEGREE>::default()),
RecursionAir::Select(SelectChip),
RecursionAir::PublicValues(PublicValuesChip),
]
.map(Chip::new)
.into_iter()
.collect::<Vec<_>>();
StarkMachine::new(config, chips, PROOF_MAX_NUM_PVS, false)
}
pub fn shrink_shape() -> RecursionShape {
let shape = HashMap::from(
[
(Self::MemoryVar(MemoryVarChip::default()), 18),
(Self::Select(SelectChip), 18),
(Self::MemoryConst(MemoryConstChip::default()), 17),
(Self::BatchFRI(BatchFRIChip::<DEGREE>), 17),
(Self::BaseAlu(BaseAluChip), 17),
(Self::ExtAlu(ExtAluChip), 18),
(Self::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>), 17),
(Self::Poseidon2Wide(Poseidon2WideChip::<DEGREE>), 16),
(Self::PublicValues(PublicValuesChip), PUB_VALUES_LOG_HEIGHT),
]
.map(|(chip, log_height)| (chip.name(), log_height)),
);
RecursionShape { inner: shape }
}
pub fn heights(program: &RecursionProgram<F>) -> Vec<(String, usize)> {
let heights = program
.inner
.iter()
.fold(RecursionAirEventCount::default(), |heights, instruction| heights + instruction);
[
(
Self::MemoryConst(MemoryConstChip::default()),
heights.mem_const_events.div_ceil(NUM_CONST_MEM_ENTRIES_PER_ROW),
),
(
Self::MemoryVar(MemoryVarChip::default()),
heights.mem_var_events.div_ceil(NUM_VAR_MEM_ENTRIES_PER_ROW),
),
(
Self::BaseAlu(BaseAluChip),
heights.base_alu_events.div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW),
),
(
Self::ExtAlu(ExtAluChip),
heights.ext_alu_events.div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW),
),
(Self::Poseidon2Wide(Poseidon2WideChip::<DEGREE>), heights.poseidon2_wide_events),
(Self::BatchFRI(BatchFRIChip::<DEGREE>), heights.batch_fri_events),
(Self::Select(SelectChip), heights.select_events),
(
Self::ExpReverseBitsLen(ExpReverseBitsLenChip::<DEGREE>),
heights.exp_reverse_bits_len_events,
),
(Self::PublicValues(PublicValuesChip), PUB_VALUES_LOG_HEIGHT),
]
.map(|(chip, log_height)| (chip.name(), log_height))
.to_vec()
}
}
impl<F> AddAssign<&Instruction<F>> for RecursionAirEventCount {
    /// Adds the events generated by a single instruction to this tally.
    ///
    /// Most instructions contribute exactly one event to their chip. Instructions that carry a
    /// list of addresses contribute one event per list element, and instructions without a
    /// counted chip (`CommitPublicValues`, `Print`, and `DebugBacktrace` when the `debug`
    /// feature is on) leave the tally untouched.
    #[inline]
    fn add_assign(&mut self, rhs: &Instruction<F>) {
        match rhs {
            // Instructions that always produce a single event.
            Instruction::Mem(_) => {
                self.mem_const_events += 1;
            }
            Instruction::BaseAlu(_) => {
                self.base_alu_events += 1;
            }
            Instruction::ExtAlu(_) => {
                self.ext_alu_events += 1;
            }
            Instruction::Poseidon2(_) => {
                self.poseidon2_wide_events += 1;
            }
            Instruction::Select(_) => {
                self.select_events += 1;
            }
            Instruction::FriFold(_) => {
                self.fri_fold_events += 1;
            }

            // Instructions whose event count depends on the length of an address list.
            Instruction::ExpReverseBitsLen(ExpReverseBitsInstr { addrs, .. }) => {
                self.exp_reverse_bits_len_events += addrs.exp.len();
            }
            Instruction::BatchFRI(instr) => {
                self.batch_fri_events += instr.base_vec_addrs.p_at_x.len();
            }

            // Hints: only the outputs are counted, the input address is ignored.
            Instruction::Hint(HintInstr { output_addrs_mults })
            | Instruction::HintBits(HintBitsInstr { output_addrs_mults, .. }) => {
                self.mem_var_events += output_addrs_mults.len();
            }
            // Kept as its own arm: it cannot share the binding of the or-pattern above.
            Instruction::HintExt2Felts(HintExt2FeltsInstr { output_addrs_mults, .. }) => {
                self.mem_var_events += output_addrs_mults.len();
            }
            Instruction::HintAddCurve(instr) => {
                self.mem_var_events += instr.output_x_addrs_mults.len();
                self.mem_var_events += instr.output_y_addrs_mults.len();
            }

            // Instructions that are not counted.
            Instruction::CommitPublicValues(_) | Instruction::Print(_) => {}
            #[cfg(feature = "debug")]
            Instruction::DebugBacktrace(_) => {}
        }
    }
}
impl<F> Add<&Instruction<F>> for RecursionAirEventCount {
    type Output = RecursionAirEventCount;

    /// Returns this tally with the events of `rhs` added.
    ///
    /// By-value counterpart of the `AddAssign<&Instruction<F>>` impl; it exists so a tally can
    /// be threaded through an iterator `fold`.
    #[inline]
    fn add(self, rhs: &Instruction<F>) -> Self::Output {
        let mut total = self;
        total += rhs;
        total
    }
}
impl From<RecursionShape> for OrderedShape {
fn from(value: RecursionShape) -> Self {
value.inner.into_iter().collect()
}
}
#[cfg(test)]
pub mod tests {
use std::{iter::once, sync::Arc};
use machine::RecursionAir;
use p3_baby_bear::DiffusionMatrixBabyBear;
use p3_field::{
extension::{BinomialExtensionField, HasFrobenius},
AbstractExtensionField, AbstractField, Field,
};
use rand::prelude::*;
use sp1_core_machine::utils::run_test_machine;
use sp1_stark::{baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};
use crate::{runtime::instruction as instr, *};
type SC = BabyBearPoseidon2;
type F = <SC as StarkGenericConfig>::Val;
type EF = <SC as StarkGenericConfig>::Challenge;
type A = RecursionAir<F, 3>;
type B = RecursionAir<F, 9>;
/// Executes `program` once and proves/verifies the resulting record on two machines.
///
/// The program is run in a `Runtime`, then its execution record is checked with
/// `run_test_machine` against
/// 1. the wide all-chips machine (`RecursionAir<F, 3>`, default `BabyBearPoseidon2` config), and
/// 2. the skinny all-chips machine (`RecursionAir<F, 9>`, `ultra_compressed` config).
///
/// # Panics
///
/// Panics if execution fails or if either machine fails verification.
pub fn run_recursion_test_machines(program: RecursionProgram<F>) {
    let program = Arc::new(program);

    let mut runtime =
        Runtime::<F, EF, DiffusionMatrixBabyBear>::new(Arc::clone(&program), SC::new().perm);
    runtime.run().unwrap();

    // Wide Poseidon2 machine. The record is cloned because it is needed again below.
    let wide = A::machine_wide_with_all_chips(BabyBearPoseidon2::default());
    let (wide_pk, wide_vk) = wide.setup(&program);
    let wide_outcome = run_test_machine(vec![runtime.record.clone()], wide, wide_pk, wide_vk);
    wide_outcome.expect("Verification failed");

    // Skinny Poseidon2 machine. This is the last use, so the record is moved.
    let skinny = B::machine_skinny_with_all_chips(BabyBearPoseidon2::ultra_compressed());
    let (skinny_pk, skinny_vk) = skinny.setup(&program);
    let skinny_outcome = run_test_machine(vec![runtime.record], skinny, skinny_pk, skinny_vk);
    skinny_outcome.expect("Verification failed");
}
/// Wraps `instrs` into a program with `linear_program` and runs it through
/// `run_recursion_test_machines`.
///
/// # Panics
///
/// Panics if `linear_program` returns an error, or if running/verifying the program fails.
pub fn test_recursion_linear_program(instrs: Vec<Instruction<F>>) {
    let program = linear_program(instrs).unwrap();
    run_recursion_test_machines(program);
}
/// Computes Fibonacci numbers in memory and checks the last two.
///
/// Address 0 is seeded with 0 and address 1 with 1. For each `i` in `2..=n`, an `AddF`
/// instruction is emitted with arguments `(i, i - 2, i - 1)` after the multiplicity.
/// Finally addresses `n - 1` and `n` are read back expecting 34 and 55.
#[test]
pub fn fibonacci() {
    let n = 10;

    // Seed values. NOTE(review): the second argument looks like the multiplicity of the
    // written cell (how often it is consumed later) — confirm against `instr::mem`.
    let seeds = once(instr::mem(MemAccessKind::Write, 1, 0, 0))
        .chain(once(instr::mem(MemAccessKind::Write, 2, 1, 1)));

    // One addition per remaining index.
    let additions = (2..=n).map(|i| instr::base_alu(BaseAluOpcode::AddF, 2, i, i - 2, i - 1));

    // Read back the last two cells.
    let checks = once(instr::mem(MemAccessKind::Read, 1, n - 1, 34))
        .chain(once(instr::mem(MemAccessKind::Read, 2, n, 55)));

    let instructions: Vec<_> = seeds.chain(additions).chain(checks).collect();
    test_recursion_linear_program(instructions);
}
/// A `DivF` whose operands are the cells holding 1 (address 1) and 0 (address 0) must make
/// the test pipeline panic.
///
/// NOTE(review): from the argument order this appears to compute `mem[1] / mem[0]`, i.e. a
/// non-zero value divided by zero — confirm against `instr::base_alu`.
#[test]
#[should_panic]
pub fn div_nonzero_by_zero() {
    let write_zero = instr::mem(MemAccessKind::Write, 1, 0, 0);
    let write_one = instr::mem(MemAccessKind::Write, 1, 1, 1);
    let divide = instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0);
    let read_result = instr::mem(MemAccessKind::Read, 1, 2, 1);

    test_recursion_linear_program(vec![write_zero, write_one, divide, read_result]);
}
/// A `DivF` whose operands are both cells holding 0 must run and verify, with the result
/// cell read back as 1.
///
/// NOTE(review): from the argument order this appears to compute `mem[1] / mem[0]`, i.e. zero
/// divided by zero — confirm against `instr::base_alu`.
#[test]
pub fn div_zero_by_zero() {
    let write_divisor = instr::mem(MemAccessKind::Write, 1, 0, 0);
    let write_dividend = instr::mem(MemAccessKind::Write, 1, 1, 0);
    let divide = instr::base_alu(BaseAluOpcode::DivF, 1, 2, 1, 0);
    let read_result = instr::mem(MemAccessKind::Read, 1, 2, 1);

    test_recursion_linear_program(vec![write_divisor, write_dividend, divide, read_result]);
}
/// Multiplies the Galois conjugates of 100 random non-zero extension elements inside the
/// machine and checks the product against the value computed natively.
///
/// For every sample, an accumulator initialised to one is written at `addr`. Each conjugate is
/// then written at `addr + 1` and multiplied into `addr + 2` with `MulE`, after which `addr`
/// advances by two so the product becomes the next accumulator. At the end the first
/// base-field coordinate of the native product is read back from the accumulator cell.
#[test]
pub fn field_norm() {
    let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
    let mut instructions = Vec::new();
    let mut addr = 0;

    for _ in 0..100 {
        // Rejection-sample coordinates until at least one of them is non-zero.
        let inner: [F; 4] = loop {
            let candidate: [F; 4] =
                core::array::from_fn(|_| rng.sample(rand::distributions::Standard));
            if !candidate.iter().all(F::is_zero) {
                break candidate;
            }
        };
        let x = BinomialExtensionField::<F, D>::from_base_slice(&inner);

        // Native running product, mirrored step by step in machine memory.
        let mut acc = BinomialExtensionField::one();
        instructions.push(instr::mem_ext(MemAccessKind::Write, 1, addr, acc));

        for conj in x.galois_group() {
            let conj_addr = addr + 1;
            let product_addr = addr + 2;
            instructions.push(instr::mem_ext(MemAccessKind::Write, 1, conj_addr, conj));
            instructions.push(instr::ext_alu(ExtAluOpcode::MulE, 1, product_addr, addr, conj_addr));

            addr = product_addr;
            acc *= conj;
        }

        // Compare the first base-field coordinate of the product.
        let base_cmp: F = acc.as_base_slice()[0];
        instructions.push(instr::mem_single(MemAccessKind::Read, 1, addr, base_cmp));
        addr += 1;
    }

    test_recursion_linear_program(instructions);
}
}