#![allow(clippy::needless_range_loop)]
use crate::{
air::{Block, IsZeroOperation, RecursionMemoryAirBuilder},
memory::{MemoryReadSingleCols, MemoryReadWriteSingleCols},
runtime::Opcode,
};
use core::borrow::Borrow;
use p3_air::{Air, AirBuilder, BaseAir};
use p3_field::{AbstractField, PrimeField32};
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use p3_util::reverse_bits_len;
use sp1_core_machine::utils::{next_power_of_two, par_for_each_row};
use sp1_derive::AlignedBorrow;
use sp1_stark::air::{BaseAirBuilder, ExtensionAirBuilder, MachineAir, SP1AirBuilder};
use std::borrow::BorrowMut;
use tracing::instrument;
use crate::{
air::SP1RecursionAirBuilder,
memory::MemoryRecord,
runtime::{ExecutionRecord, RecursionProgram},
};
/// Width of the chip's trace in columns.
///
/// Computed as the byte size of `ExpReverseBitsLenCols` instantiated with one-byte (`u8`) cells.
pub const NUM_EXP_REVERSE_BITS_LEN_COLS: usize = std::mem::size_of::<ExpReverseBitsLenCols<u8>>();
/// Chip for the `ExpReverseBitsLen` opcode: raises `x` to an exponent that is read bit by bit
/// from memory, one bit per trace row.
///
/// When `DEGREE > 3`, `eval_exp_reverse_bits_len` adds a trivially-true constraint of degree
/// `DEGREE` (presumably to normalize the AIR's constraint degree — confirm with the machine config).
#[derive(Default)]
pub struct ExpReverseBitsLenChip<const DEGREE: usize> {
    /// Optional fixed log2 height, forwarded to `next_power_of_two` when padding the trace.
    pub fixed_log2_rows: Option<usize>,
    /// Whether `generate_trace` pads the trace with all-zero rows up to
    /// `next_power_of_two(nb_events, fixed_log2_rows)`.
    pub pad: bool,
}
/// One iteration (one exponent bit) of an exp-reverse-bits-len computation; each event becomes
/// one trace row.
#[derive(Debug, Clone)]
pub struct ExpReverseBitsLenEvent<F> {
    /// Clock of this iteration.
    pub clk: F,
    /// Memory record for the base `x`; its `prev_value[0]` is used as the multiplier when the
    /// current bit is 1.
    pub x: MemoryRecord<F>,
    /// Memory record holding the exponent bit consumed by this iteration.
    pub current_bit: MemoryRecord<F>,
    /// Remaining number of iterations including this one; `len == 1` marks the last iteration.
    pub len: F,
    /// Accumulator before this iteration (the trace stores its square).
    pub prev_accum: F,
    /// Accumulator after this iteration.
    pub accum: F,
    /// Memory address of the exponent bit read by this iteration.
    pub ptr: F,
    /// Memory address of `x`.
    pub base_ptr: F,
    /// Zero-based index of this iteration; `0` marks the first iteration.
    pub iteration_num: F,
}
impl<F: PrimeField32> ExpReverseBitsLenEvent<F> {
    /// Builds the events for one computation of `x` raised to `exponent` with its low `len`
    /// bits reversed, starting at clock `timestamp`. One event is emitted per bit, least
    /// significant bit first.
    ///
    /// Addresses are fixed (`base_ptr = 1`, `ptr = iteration index`), which makes this suitable
    /// for exercising the chip without a runtime.
    ///
    /// # Panics
    ///
    /// Panics if the final accumulator differs from
    /// `x.exp_u64(reverse_bits_len(exponent, len))` (a self-check of the recurrence).
    pub fn dummy_from_input(x: F, exponent: u32, len: F, timestamp: F) -> Vec<Self> {
        let num_bits = len.as_canonical_u32();
        let mut remaining_len = len;
        let mut remaining_bits = exponent;
        let mut accum = F::one();
        let mut events = Vec::new();

        for step in 0..num_bits {
            // Consume the exponent least-significant bit first.
            let bit = remaining_bits & 1;
            let prev_accum = accum;
            let multiplier = if bit == 0 { F::one() } else { x };
            accum = prev_accum * prev_accum * multiplier;

            let step_f = F::from_canonical_u32(step);
            let clk = timestamp + step_f;
            let prev_clk = clk - F::one();
            // The final iteration overwrites `x` with the result; earlier ones leave it as is.
            let written_x = if step + 1 == num_bits { accum } else { x };

            events.push(Self {
                clk,
                x: MemoryRecord::new_write(
                    F::one(),
                    Block::from([written_x, F::zero(), F::zero(), F::zero()]),
                    clk,
                    Block::from([x, F::zero(), F::zero(), F::zero()]),
                    prev_clk,
                ),
                // NOTE(review): this record uses address 0 while `ptr` is the iteration index;
                // harmless if the record's address is not consumed by the trace — confirm.
                current_bit: MemoryRecord::new_read(
                    F::zero(),
                    Block::from([F::from_canonical_u32(bit), F::zero(), F::zero(), F::zero()]),
                    clk,
                    prev_clk,
                ),
                len: remaining_len,
                prev_accum,
                accum,
                ptr: step_f,
                base_ptr: F::one(),
                iteration_num: step_f,
            });

            remaining_bits >>= 1;
            remaining_len -= F::one();
        }

        let reversed_exponent = reverse_bits_len(exponent as usize, num_bits as usize) as u64;
        assert_eq!(accum, x.exp_u64(reversed_exponent));
        events
    }
}
/// Trace columns for one iteration (one exponent bit) of an exp-reverse-bits-len computation.
///
/// The struct is `#[repr(C)]` and borrowed directly from a trace row, so field order is the
/// column order.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExpReverseBitsLenCols<T: Copy> {
    /// Clock of this iteration; increments by one between rows of the same computation.
    pub clk: T,
    /// Memory access for the base `x` at `base_ptr`; the last row writes the result here.
    pub x: MemoryReadWriteSingleCols<T>,
    /// Remaining number of iterations including this one; `len == 1` marks the last row.
    pub len: T,
    /// Memory read of the exponent bit at `ptr`; constrained to be boolean.
    pub current_bit: MemoryReadSingleCols<T>,
    /// Square of the previous iteration's accumulator.
    pub prev_accum_squared: T,
    /// Accumulator after this iteration.
    pub accum: T,
    /// Zero test on `1 - len`; its result flags the last row of a computation.
    pub is_last: IsZeroOperation<T>,
    /// Zero test on `iteration_num`; its result flags the first row of a computation.
    pub is_first: IsZeroOperation<T>,
    /// Zero-based iteration index within the computation.
    pub iteration_num: T,
    /// `x` when the current bit is 1, and 1 otherwise.
    pub multiplier: T,
    /// Memory address of the exponent bit read on this row.
    pub ptr: T,
    /// Memory address of `x`.
    pub base_ptr: T,
    /// Multiplicity of the `x` memory access: `is_first OR is_last`, and 0 on padding rows.
    pub x_mem_access_flag: T,
    /// 1 on rows populated from events, 0 on padding rows.
    pub is_real: T,
}
impl<F, const DEGREE: usize> BaseAir<F> for ExpReverseBitsLenChip<DEGREE> {
    /// Number of trace columns: one per cell of `ExpReverseBitsLenCols`.
    fn width(&self) -> usize {
        NUM_EXP_REVERSE_BITS_LEN_COLS
    }
}
impl<F: PrimeField32, const DEGREE: usize> MachineAir<F> for ExpReverseBitsLenChip<DEGREE> {
type Record = ExecutionRecord<F>;
type Program = RecursionProgram<F>;
fn name(&self) -> String {
"ExpReverseBitsLen".to_string()
}
fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
}
#[instrument(name = "generate exp reverse bits len trace", level = "debug", skip_all, fields(rows = input.exp_reverse_bits_len_events.len()))]
fn generate_trace(
&self,
input: &ExecutionRecord<F>,
_: &mut ExecutionRecord<F>,
) -> RowMajorMatrix<F> {
let nb_events = input.exp_reverse_bits_len_events.len();
let nb_rows =
if self.pad { next_power_of_two(nb_events, self.fixed_log2_rows) } else { nb_events };
let mut values = vec![F::zero(); nb_rows * NUM_EXP_REVERSE_BITS_LEN_COLS];
par_for_each_row(&mut values, NUM_EXP_REVERSE_BITS_LEN_COLS, |i, row| {
if i >= nb_events {
return;
}
let event = &input.exp_reverse_bits_len_events[i];
let cols: &mut ExpReverseBitsLenCols<F> = row.borrow_mut();
cols.clk = event.clk;
cols.x.populate(&event.x);
cols.current_bit.populate(&event.current_bit);
cols.len = event.len;
cols.accum = event.accum;
cols.prev_accum_squared = event.prev_accum * event.prev_accum;
cols.is_last.populate(F::one() - event.len);
cols.is_first.populate(event.iteration_num);
cols.is_real = F::one();
cols.iteration_num = event.iteration_num;
cols.multiplier =
if event.current_bit.value == Block([F::one(), F::zero(), F::zero(), F::zero()]) {
event.x.prev_value[0]
} else {
F::one()
};
cols.ptr = event.ptr;
cols.base_ptr = event.base_ptr;
cols.x_mem_access_flag =
F::from_bool(cols.len == F::one() || cols.iteration_num == F::zero());
});
let trace = RowMajorMatrix::new(values, NUM_EXP_REVERSE_BITS_LEN_COLS);
#[cfg(debug_assertions)]
println!(
"exp reverse bits len trace dims is width: {:?}, height: {:?}",
trace.width(),
trace.height()
);
trace
}
fn included(&self, record: &Self::Record) -> bool {
!record.exp_reverse_bits_len_events.is_empty()
}
}
impl<const DEGREE: usize> ExpReverseBitsLenChip<DEGREE> {
    /// Evaluates the chip's constraints on a `(local, next)` row pair.
    ///
    /// Each computation occupies consecutive rows. It starts at a row with `is_first` on
    /// (`iteration_num == 0`) and ends at a row with `is_last` on (`len == 1`). Each row reads
    /// one exponent bit at `ptr` with multiplicity `memory_access`. The row's accumulator is the
    /// previous accumulator squared, times `x` when the bit is 1. The last row writes the final
    /// accumulator back to `base_ptr`.
    ///
    /// NOTE(review): Plonky3-style constraint folders combine assertions in evaluation order, so
    /// reordering or removing the assertions below presumably changes the resulting proofs.
    pub fn eval_exp_reverse_bits_len<
        AB: BaseAirBuilder + ExtensionAirBuilder + RecursionMemoryAirBuilder + SP1AirBuilder,
    >(
        &self,
        builder: &mut AB,
        local: &ExpReverseBitsLenCols<AB::Var>,
        next: &ExpReverseBitsLenCols<AB::Var>,
        memory_access: AB::Var,
    ) {
        // Trivially-true constraint `is_real^DEGREE == is_real^DEGREE`, only for DEGREE > 3;
        // presumably used to normalize the AIR's constraint degree to DEGREE.
        if DEGREE > 3 {
            let lhs = (0..DEGREE).map(|_| local.is_real.into()).product::<AB::Expr>();
            let rhs = (0..DEGREE).map(|_| local.is_real.into()).product::<AB::Expr>();
            builder.assert_eq(lhs, rhs);
        }

        // Receive the (clk, base_ptr, ptr, len) operands of the `ExpReverseBitsLen` opcode once
        // per computation: the multiplicity is `is_first`.
        let operands =
            [local.clk.into(), local.base_ptr.into(), local.ptr.into(), local.len.into()];
        builder.receive_table(
            Opcode::ExpReverseBitsLen.as_field::<AB::F>(),
            &operands,
            local.is_first.result,
        );

        // Padding rows must not receive operands.
        builder.when_not(local.is_real).assert_zero(local.is_first.result);

        // `is_last.result` is 1 iff `len == 1` (enforced on real rows).
        IsZeroOperation::<AB::F>::eval(
            builder,
            AB::Expr::one() - local.len,
            local.is_last,
            local.is_real.into(),
        );

        // `is_first.result` is 1 iff `iteration_num == 0` (enforced on real rows).
        IsZeroOperation::<AB::F>::eval(
            builder,
            local.iteration_num.into(),
            local.is_first,
            local.is_real.into(),
        );

        // Real rows form a prefix of the trace: a padding row is never followed by a real row.
        builder.when_transition().assert_zero((AB::Expr::one() - local.is_real) * next.is_real);
        builder.assert_bool(local.is_real);

        // The exponent bit read from memory must be boolean.
        let current_bit_val = local.current_bit.access.value;
        builder.assert_bool(current_bit_val);

        // The first row of the trace starts a computation.
        builder.when_first_row().assert_one(local.is_first.result);

        // A real row following an `is_last` row starts a new computation.
        builder
            .when_transition()
            .when(next.is_real * local.is_last.result)
            .assert_one(next.is_first.result);

        // On the first row of a computation the accumulator is just the multiplier.
        builder.when(local.is_first.result).assert_eq(local.accum, local.multiplier);

        // The last real row must close its computation, whether it is followed by a padding
        // row or is the final row of the trace.
        builder
            .when_transition()
            .when(local.is_real * (AB::Expr::one() - next.is_real))
            .assert_one(local.is_last.result);
        builder.when_last_row().when(local.is_real).assert_one(local.is_last.result);

        // `multiplier` is `x` (the value previously stored at `base_ptr`) when the bit is 1,
        // and 1 on real rows where the bit is 0.
        builder.when(current_bit_val).assert_eq(local.multiplier, local.x.prev_value);
        builder
            .when(local.is_real)
            .when_not(current_bit_val)
            .assert_eq(local.multiplier, AB::Expr::one());

        // On non-first rows: accum = prev_accum_squared * multiplier.
        builder
            .when_not(local.is_first.result)
            .assert_eq(local.accum, local.prev_accum_squared * local.multiplier);

        // Within a computation, the next row's `prev_accum_squared` is this row's accum squared.
        builder
            .when_transition()
            .when_not(local.is_last.result)
            .assert_eq(next.prev_accum_squared, local.accum * local.accum);

        // `base_ptr` is constant within a computation.
        builder
            .when_transition()
            .when_not(local.is_last.result)
            .assert_eq(local.base_ptr, next.base_ptr);

        // `ptr` advances by one per iteration.
        builder
            .when_transition()
            .when(next.is_real)
            .when_not(local.is_last.result)
            .assert_eq(next.ptr, local.ptr + AB::Expr::one());

        // `len` counts down by one per iteration.
        builder
            .when_transition()
            .when(local.is_real)
            .when_not(local.is_last.result)
            .assert_eq(local.len, next.len + AB::Expr::one());

        // `iteration_num` counts up by one per iteration.
        builder
            .when_transition()
            .when(local.is_real)
            .when_not(local.is_last.result)
            .assert_eq(local.iteration_num + AB::Expr::one(), next.iteration_num);

        // `iteration_num` is zero on every `is_first` row.
        builder.when(local.is_first.result).assert_eq(local.iteration_num, AB::Expr::zero());

        // Read the exponent bit at `ptr` with multiplicity `memory_access`.
        builder.recursion_eval_memory_access_single(
            local.clk,
            local.ptr,
            &local.current_bit,
            memory_access,
        );

        // `x_mem_access_flag = is_first OR is_last` on real rows...
        builder.when(local.is_real).assert_eq(
            local.x_mem_access_flag,
            local.is_first.result + local.is_last.result
                - local.is_first.result * local.is_last.result,
        );

        // ...and zero on padding rows.
        builder.when_not(local.is_real).assert_zero(local.x_mem_access_flag);

        // Access `x` at `base_ptr` only on the first and last rows of each computation.
        builder.recursion_eval_memory_access_single(
            local.clk,
            local.base_ptr,
            &local.x,
            local.x_mem_access_flag,
        );

        // NOTE(review): implied by the unconditional `base_ptr` equality above; kept because
        // removing a constraint changes the AIR.
        builder
            .when_transition()
            .when(next.is_real)
            .when_not(local.is_last.result)
            .assert_eq(next.base_ptr, local.base_ptr);

        // `clk` advances by one per iteration.
        builder
            .when_transition()
            .when_not(local.is_last.result)
            .when(next.is_real)
            .assert_eq(local.clk + AB::Expr::one(), next.clk);

        // `x` is carried unchanged between iterations: the next row sees this row's value, and
        // only an `is_last` row may write a different value.
        builder
            .when_transition()
            .when(next.is_real)
            .when_not(local.is_last.result)
            .assert_eq(local.x.access.value, next.x.prev_value);
        builder
            .when_transition()
            .when_not(local.is_last.result)
            .assert_eq(local.x.access.value, local.x.prev_value);

        // The last row writes the final accumulator to `base_ptr`.
        builder.when(local.is_last.result).assert_eq(local.accum, local.x.access.value);
    }

    /// Multiplicity of the exponent-bit memory read: one read per real row.
    pub const fn do_exp_bit_memory_access<T: Copy>(local: &ExpReverseBitsLenCols<T>) -> T {
        local.is_real
    }
}
impl<AB, const DEGREE: usize> Air<AB> for ExpReverseBitsLenChip<DEGREE>
where
    AB: SP1RecursionAirBuilder,
{
    /// Applies the chip's constraints to the current and next rows of the main trace, using
    /// `is_real` as the multiplicity of the exponent-bit memory read.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local_row = main.row_slice(0);
        let next_row = main.row_slice(1);
        let local: &ExpReverseBitsLenCols<AB::Var> = (*local_row).borrow();
        let next: &ExpReverseBitsLenCols<AB::Var> = (*next_row).borrow();
        let bit_access_multiplicity = Self::do_exp_bit_memory_access::<AB::Var>(local);
        self.eval_exp_reverse_bits_len::<AB>(builder, local, next, bit_access_multiplicity);
    }
}
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::{dense::RowMajorMatrix, Matrix};
    use sp1_core_machine::utils::{uni_stark_prove, uni_stark_verify};
    use sp1_stark::{air::MachineAir, baby_bear_poseidon2::BabyBearPoseidon2, StarkGenericConfig};

    use crate::{
        exp_reverse_bits::{ExpReverseBitsLenChip, ExpReverseBitsLenEvent},
        runtime::ExecutionRecord,
    };

    /// Proves and verifies a trace with one computation per exponent in `1..16`. Each
    /// computation uses the base `x = exponent` and the exponent's bit length as `len`.
    #[test]
    fn prove_babybear() {
        let config = BabyBearPoseidon2::compressed();
        let mut prover_challenger = config.challenger();
        let chip = ExpReverseBitsLenChip::<5> { pad: true, fixed_log2_rows: None };

        let mut input_exec = ExecutionRecord::<BabyBear>::default();
        for exponent in 1u32..16 {
            let x = BabyBear::from_canonical_u32(exponent);
            let len = BabyBear::from_canonical_u32(exponent.ilog2() + 1);
            // The base doubles as the starting timestamp of the computation.
            let mut events = ExpReverseBitsLenEvent::dummy_from_input(x, exponent, len, x);
            input_exec.exp_reverse_bits_len_events.append(&mut events);
        }
        println!("input exec: {:?}", input_exec.exp_reverse_bits_len_events.len());

        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&input_exec, &mut ExecutionRecord::<BabyBear>::default());
        println!("trace dims is width: {:?}, height: {:?}", trace.width(), trace.height());

        let prove_start = Instant::now();
        let proof = uni_stark_prove(&config, &chip, &mut prover_challenger, trace);
        println!("proof duration = {:?}", prove_start.elapsed().as_secs_f64());

        // A fresh challenger replays the transcript from the start for verification.
        let mut verifier_challenger = config.challenger();
        let verify_start = Instant::now();
        uni_stark_verify(&config, &chip, &mut verifier_challenger, &proof)
            .expect("expected proof to be valid");
        println!("verify duration = {:?}", verify_start.elapsed().as_secs_f64());
    }
}