mod alu;
mod branch;
mod heap;
mod jump;
mod memory;
mod operands;
mod public_values;
mod system;
use std::borrow::Borrow;
use p3_air::{Air, AirBuilder};
use p3_field::{AbstractField, Field};
use p3_matrix::Matrix;
use sp1_stark::air::BaseAirBuilder;
use crate::{
air::{RecursionPublicValues, SP1RecursionAirBuilder, RECURSIVE_PROOF_NUM_PV_ELTS},
cpu::{columns::SELECTOR_COL_MAP, CpuChip, CpuCols},
memory::MemoryCols,
};
/// AIR constraint definition for the recursion CPU chip.
///
/// The const parameter `L` is only used at the end of `eval`, where a product of `is_real`
/// factors is built (presumably to pin the chip's constraint degree to `L` — TODO confirm).
impl<AB, const L: usize> Air<AB> for CpuChip<AB::F, L>
where
AB: SP1RecursionAirBuilder,
{
/// Emits every constraint of the CPU chip onto `builder`.
///
/// The constraints are expressed over two adjacent rows of the main trace (`local` and
/// `next`) and over the builder's public values. Most instruction-specific constraints are
/// delegated to `eval_*` helpers on `CpuChip`; this method wires them together and adds the
/// padding-row, first-row, program-counter and table-send constraints.
fn eval(&self, builder: &mut AB) {
// View row 0 and row 1 of the main trace as typed `CpuCols`.
let main = builder.main();
let (local, next) = (main.row_slice(0), main.row_slice(1));
let local: &CpuCols<AB::Var> = (*local).borrow();
let next: &CpuCols<AB::Var> = (*next).borrow();
// Convert the first `RECURSIVE_PROOF_NUM_PV_ELTS` public values into expressions and view
// them as a typed `RecursionPublicValues`.
let pv = builder.public_values();
let pv_elms: [AB::Expr; RECURSIVE_PROOF_NUM_PV_ELTS] =
core::array::from_fn(|i| pv[i].into());
let public_values: &RecursionPublicValues<AB::Expr> = pv_elms.as_slice().borrow();
let zero = AB::Expr::zero();
let one = AB::Expr::one();
// Rows where `is_real` is not set must have `imm_b`, `imm_c` and the `is_noop` selector
// equal to one ...
builder.when_not(local.is_real).assert_one(local.instruction.imm_b);
builder.when_not(local.is_real).assert_one(local.instruction.imm_c);
builder.when_not(local.is_real).assert_one(local.selectors.is_noop);
// ... and every selector other than `is_noop` equal to zero.
local
.selectors
.into_iter()
.enumerate()
.filter(|(i, _)| *i != SELECTOR_COL_MAP.is_noop)
.for_each(|(_, selector)| builder.when_not(local.is_real).assert_zero(selector));
// On the first row both `clk` and `pc` are zero.
builder.when_first_row().assert_zero(local.clk);
builder.when_first_row().assert_zero(local.pc);
// Send (pc, instruction, selectors) to the program interaction, with `is_real` as the last
// argument (presumably the multiplicity — confirm against `send_program`).
builder.send_program(local.pc, local.instruction, local.selectors, local.is_real);
// Operand, memory and ALU constraints live in helper methods defined outside this block.
self.eval_operands(builder, local);
self.eval_memory(builder, local);
self.eval_alu(builder, local);
// Program-counter transition.
{
// `next_pc` starts at zero and is handed by mutable reference to `eval_branch` and
// `eval_jump`, which may add their own contributions to it.
let mut next_pc = zero;
self.eval_branch(builder, local, &mut next_pc);
self.eval_jump(builder, local, next, &mut next_pc);
// For rows that are neither branch nor jump instructions, add the fall-through
// target `pc + 1`.
let not_branch_or_jump = one.clone()
- self.is_branch_instruction::<AB>(local)
- self.is_jump_instruction::<AB>(local);
next_pc += not_branch_or_jump.clone() * (local.pc + one);
// Only enforced on transition rows whose next row is real.
builder.when_transition().when(next.is_real).assert_eq(next_pc, next.pc);
}
// Rows flagged as poseidon, fri_fold or exp_reverse_bits_len send their opcode together
// with (clk, a[0], b[0], c[0] + offset_imm) via `send_table`; the selector sum is passed as
// the last argument.
let send_syscall = local.selectors.is_poseidon
+ local.selectors.is_fri_fold
+ local.selectors.is_exp_reverse_bits_len;
let operands = [
local.clk.into(),
local.a.value()[0].into(),
local.b.value()[0].into(),
local.c.value()[0] + local.instruction.offset_imm,
];
builder.send_table(local.instruction.opcode, &operands, send_syscall);
// Commit constraints receive a clone of the `digest` field of the public values.
self.eval_commit(builder, local, public_values.digest.clone());
// Clock progression between `local` and `next`.
self.eval_clk(builder, local, next);
// System-instruction constraints (helper defined outside this block).
self.eval_system_instructions(builder, local, next, public_values);
// Heap-pointer constraints (helper defined outside this block).
self.eval_heap_ptr(builder, local);
// `is_real` is boolean, set on the first row, and never turns back on once cleared.
self.eval_is_real(builder, local, next);
// Build `is_real^L` (two factors up front, then `L - 2` more) and assert it equals itself.
// The assertion is trivially satisfied; presumably it exists only to force the constraint
// degree to `L` — TODO confirm.
// NOTE(review): `L - 2` underflows for `L < 2`; this assumes the chip is always
// instantiated with `L >= 2`.
let mut expr = local.is_real * local.is_real;
for _ in 0..(L - 2) {
expr *= local.is_real.into();
}
builder.assert_eq(expr.clone(), expr.clone());
}
}
impl<F: Field, const L: usize> CpuChip<F, L> {
    /// Constrains how `clk` advances from `local` to `next`.
    ///
    /// Every constraint is restricted to transition rows whose `next` row is real:
    /// - filtered by `when_not(is_fri_fold + is_exp_reverse_bits_len)`:
    ///   `next.clk == local.clk + 4`;
    /// - filtered by `is_fri_fold`: `next.clk == local.clk + a[0]`;
    /// - filtered by `is_exp_reverse_bits_len`: `next.clk == local.clk + c[0]`.
    pub fn eval_clk<AB>(&self, builder: &mut AB, local: &CpuCols<AB::Var>, next: &CpuCols<AB::Var>)
    where
        AB: SP1RecursionAirBuilder,
    {
        let clk: AB::Expr = local.clk.into();
        let is_fri_fold = local.selectors.is_fri_fold;
        let is_exp_reverse_bits_len = local.selectors.is_exp_reverse_bits_len;

        // Default step of 4 for everything except the two variable-length instructions.
        let default_next_clk = clk.clone() + AB::F::from_canonical_u32(4);
        builder
            .when_transition()
            .when(next.is_real)
            .when_not(is_fri_fold + is_exp_reverse_bits_len)
            .assert_eq(default_next_clk, next.clk);

        // FRI fold advances the clock by the first element of operand `a`.
        let fri_fold_next_clk = clk.clone() + local.a.value()[0];
        builder
            .when_transition()
            .when(next.is_real)
            .when(is_fri_fold)
            .assert_eq(fri_fold_next_clk, next.clk);

        // Exp-reverse-bits-len advances the clock by the first element of operand `c`.
        let exp_reverse_bits_next_clk = clk + local.c.value()[0];
        builder
            .when_transition()
            .when(next.is_real)
            .when(is_exp_reverse_bits_len)
            .assert_eq(exp_reverse_bits_next_clk, next.clk);
    }

    /// Constrains the `is_real` flag.
    ///
    /// The flag must be boolean, must be one on the first row, and once a row has it cleared
    /// the following row must have it cleared as well.
    pub fn eval_is_real<AB>(
        &self,
        builder: &mut AB,
        local: &CpuCols<AB::Var>,
        next: &CpuCols<AB::Var>,
    ) where
        AB: SP1RecursionAirBuilder,
    {
        let is_real = local.is_real;

        builder.assert_bool(is_real);
        builder.when_first_row().assert_one(is_real);
        builder.when_transition().when_not(is_real).assert_zero(next.is_real);
    }

    /// Sum of the `is_add`, `is_sub`, `is_mul` and `is_div` selectors.
    pub fn is_alu_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_add + sel.is_sub + sel.is_mul + sel.is_div
    }

    /// Sum of the `is_beq`, `is_bne` and `is_bneinc` selectors.
    pub fn is_branch_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_beq + sel.is_bne + sel.is_bneinc
    }

    /// Sum of the `is_jal` and `is_jalr` selectors.
    pub fn is_jump_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_jal + sel.is_jalr
    }

    /// Sum of the `is_load` and `is_store` selectors.
    pub fn is_memory_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_load + sel.is_store
    }

    /// Sum of the selectors of every instruction treated as reading, rather than writing,
    /// operand `a`.
    pub fn is_op_a_read_only_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_beq
            + sel.is_bne
            + sel.is_fri_fold
            + sel.is_poseidon
            + sel.is_store
            + sel.is_noop
            + sel.is_ext_to_felt
            + sel.is_commit
            + sel.is_trap
            + sel.is_halt
            + sel.is_exp_reverse_bits_len
    }

    /// The `is_commit` selector as an expression.
    pub fn is_commit_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let is_commit = local.selectors.is_commit;
        is_commit.into()
    }

    /// Sum of the `is_trap` and `is_halt` selectors.
    pub fn is_system_instruction<AB>(&self, local: &CpuCols<AB::Var>) -> AB::Expr
    where
        AB: SP1RecursionAirBuilder<F = F>,
    {
        let sel = &local.selectors;
        sel.is_trap + sel.is_halt
    }
}
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, DiffusionMatrixBabyBear};
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};
    use p3_poseidon2::{Poseidon2, Poseidon2ExternalMatrixGeneral};
    use sp1_core_machine::utils::{uni_stark_prove, uni_stark_verify};
    use sp1_stark::{air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};
    use std::time::Instant;

    use crate::{air::Block, memory::MemoryGlobalChip, runtime::ExecutionRecord};

    /// Proves and verifies a single-chip STARK over a small synthetic execution record,
    /// printing the trace dimensions and the prove/verify wall-clock times.
    ///
    /// NOTE(review): despite the name, the chip under test here is `MemoryGlobalChip`, not the
    /// CPU chip — confirm whether that is intentional.
    #[test]
    fn test_cpu_unistark() {
        let config = BabyBearPoseidon2::compressed();
        let chip = MemoryGlobalChip { fixed_log2_rows: None };

        // Sixteen last-memory events `(i, i, 0)` and one first-memory event `(0, 0)`.
        let mut input_exec = ExecutionRecord::<BabyBear>::default();
        let last_events = (0..16)
            .map(|i| {
                let val = BabyBear::from_canonical_u32(i);
                (val, val, Block::from(BabyBear::zero()))
            })
            .collect_vec();
        input_exec.last_memory_record.extend(last_events);
        input_exec.first_memory_record.push((BabyBear::zero(), Block::from(BabyBear::zero())));
        println!("input exec: {:?}", input_exec.last_memory_record.len());

        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
        println!("trace dims is width: {:?}, height: {:?}", trace.width(), trace.height());

        // Prove with a fresh challenger.
        let mut prover_challenger = config.challenger();
        let prove_start = Instant::now();
        let proof = uni_stark_prove(&config, &chip, &mut prover_challenger, trace);
        println!("proof duration = {:?}", prove_start.elapsed().as_secs_f64());

        // Verify with another fresh challenger so the transcript starts from the same state.
        let mut verifier_challenger: p3_challenger::DuplexChallenger<
            BabyBear,
            Poseidon2<BabyBear, Poseidon2ExternalMatrixGeneral, DiffusionMatrixBabyBear, 16, 7>,
            16,
            8,
        > = config.challenger();
        let verify_start = Instant::now();
        uni_stark_verify(&config, &chip, &mut verifier_challenger, &proof)
            .expect("expected proof to be valid");
        println!("verify duration = {:?}", verify_start.elapsed().as_secs_f64());
    }
}