use core::borrow::{Borrow, BorrowMut};
use core::mem::size_of;
use std::array;
use p3_air::BaseAir;
use p3_air::{Air, AirBuilder};
use p3_field::{AbstractField, PrimeField32};
use p3_matrix::dense::RowMajorMatrix;
use p3_matrix::Matrix;
use sp1_derive::AlignedBorrow;
use super::MemoryInitializeFinalizeEvent;
use crate::air::{AirInteraction, BaseAirBuilder, PublicValues, SP1AirBuilder, Word};
use crate::air::{MachineAir, SP1_PROOF_NUM_PV_ELTS};
use crate::operations::{AssertLtColsBits, BabyBearBitDecomposition, IsZeroOperation};
use crate::runtime::{ExecutionRecord, Program};
use crate::utils::pad_to_power_of_two;
/// Selects which set of memory events a [`MemoryChip`] operates on.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryChipType {
/// The chip processes the record's `memory_initialize_events`.
Initialize,
/// The chip processes the record's `memory_finalize_events`.
Finalize,
}
/// Chip that handles either the memory-initialize or the memory-finalize events of an
/// execution record, depending on its `kind`.
pub struct MemoryChip {
/// Which of the two event sets this chip instance is responsible for.
pub kind: MemoryChipType,
}
impl MemoryChip {
pub const fn new(kind: MemoryChipType) -> Self {
Self { kind }
}
}
impl<F> BaseAir<F> for MemoryChip {
/// Width of the chip's trace, given by `NUM_MEMORY_INIT_COLS`.
fn width(&self) -> usize {
NUM_MEMORY_INIT_COLS
}
}
impl<F: PrimeField32> MachineAir<F> for MemoryChip {
    type Record = ExecutionRecord;
    type Program = Program;

    /// Returns the chip's name, which depends on whether it initializes or finalizes memory.
    fn name(&self) -> String {
        let label = match self.kind {
            MemoryChipType::Initialize => "MemoryInit",
            MemoryChipType::Finalize => "MemoryFinalize",
        };
        label.to_string()
    }

    /// Builds the trace for this chip from the events of `input` matching `self.kind`.
    ///
    /// Events are sorted by address and each one becomes a row. The first row is compared
    /// against the "previous address" bits taken from the record's public values (only when
    /// that address is non-zero); every later row is compared against the row before it.
    /// The resulting matrix is then handed to `pad_to_power_of_two`.
    ///
    /// # Panics
    ///
    /// Panics if two adjacent events (after sorting) share the same address. In debug builds
    /// it also panics if a non-zero previous address is not strictly below the first address.
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        // Pick the event list and the matching public "previous address" bits in one go.
        let (mut events, previous_addr_bits) = match self.kind {
            MemoryChipType::Initialize => (
                input.memory_initialize_events.clone(),
                input.public_values.previous_init_addr_bits,
            ),
            MemoryChipType::Finalize => (
                input.memory_finalize_events.clone(),
                input.public_values.previous_finalize_addr_bits,
            ),
        };
        events.sort_by_key(|event| event.addr);

        let num_events = events.len();
        let mut values: Vec<F> = Vec::with_capacity(num_events * NUM_MEMORY_INIT_COLS);

        for (idx, event) in events.iter().enumerate() {
            let MemoryInitializeFinalizeEvent {
                addr,
                value,
                shard,
                timestamp,
                used,
            } = *event;

            // Little-endian bit decomposition of this row's address.
            let addr_bits: [_; 32] = array::from_fn(|bit| (addr >> bit) & 1);

            let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
            {
                let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();

                cols.shard = F::from_canonical_u32(shard);
                cols.timestamp = F::from_canonical_u32(timestamp);
                cols.addr = F::from_canonical_u32(addr);
                cols.addr_bits.populate(addr);
                cols.value = array::from_fn(|bit| F::from_canonical_u32((value >> bit) & 1));
                cols.is_real = F::from_canonical_u32(used);

                if idx == 0 {
                    // First row: compare against the address carried in the public values.
                    let prev_addr = previous_addr_bits
                        .iter()
                        .enumerate()
                        .map(|(j, bit)| bit * (1 << j))
                        .sum::<u32>();
                    cols.is_prev_addr_zero.populate(prev_addr);

                    let prev_is_nonzero = prev_addr != 0;
                    cols.is_first_comp = F::from_bool(prev_is_nonzero);
                    if prev_is_nonzero {
                        debug_assert!(prev_addr < addr, "prev_addr {} < addr {}", prev_addr, addr);
                        cols.lt_cols.populate(&previous_addr_bits, &addr_bits);
                    }
                } else {
                    // Later rows: compare against the preceding event.
                    let predecessor = &events[idx - 1];
                    cols.is_next_comp = F::from_canonical_u32(predecessor.used);
                    assert_ne!(predecessor.addr, addr);

                    let prev_addr_bits: [_; 32] =
                        array::from_fn(|bit| (predecessor.addr >> bit) & 1);
                    cols.lt_cols.populate(&prev_addr_bits, &addr_bits);
                }

                // Flag the final event of the sorted list.
                if idx + 1 == num_events {
                    cols.is_last_addr = F::one();
                }
            }

            values.extend_from_slice(&row);
        }

        let mut trace = RowMajorMatrix::new(values, NUM_MEMORY_INIT_COLS);
        pad_to_power_of_two::<NUM_MEMORY_INIT_COLS, F>(&mut trace.values);
        trace
    }

    /// Reports whether `shard` holds at least one event of this chip's kind.
    fn included(&self, shard: &Self::Record) -> bool {
        let events = match self.kind {
            MemoryChipType::Initialize => &shard.memory_initialize_events,
            MemoryChipType::Finalize => &shard.memory_finalize_events,
        };
        !events.is_empty()
    }
}
/// Column layout of one row of the memory initialize/finalize trace.
///
/// The struct is `#[repr(C)]` and is borrowed directly from a flat row slice, so the order
/// of the fields is the order of the columns.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct MemoryInitCols<T> {
/// Shard of the memory event.
pub shard: T,
/// Timestamp of the memory event.
pub timestamp: T,
/// Address of the memory event.
pub addr: T,
/// Witness columns for the bitwise comparison between the previous address and `addr`.
pub lt_cols: AssertLtColsBits<T, 32>,
/// Bit decomposition of `addr`.
pub addr_bits: BabyBearBitDecomposition<T>,
/// The 32 bits of the event's value, least-significant bit first.
pub value: [T; 32],
/// Set from the event's `used` flag; constrained to be boolean.
pub is_real: T,
/// On every row but the first, set to the `used` flag of the preceding event; it gates the
/// comparison against the previous row's address.
pub is_next_comp: T,
/// Is-zero check of the public "previous address"; populated on the first row only.
pub is_prev_addr_zero: IsZeroOperation<T>,
/// On the first row, set when the public "previous address" is non-zero; it gates the
/// comparison of the first address against that previous address.
pub is_first_comp: T,
/// Set to one on the row of the last event during trace generation.
pub is_last_addr: T,
}
/// Number of columns in a [`MemoryInitCols`] row, computed as the byte size of
/// `MemoryInitCols<u8>` (presumably every column is a single `T`, so bytes == columns).
pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();
impl<AB> Air<AB> for MemoryChip
where
AB: SP1AirBuilder,
{
/// Emits the constraints and the memory interaction for the initialize/finalize chip.
///
/// The constraints cover: booleanity of the flags and value bits, the memory interaction,
/// the address bit decomposition, address comparisons between consecutive rows and against
/// the public "previous address", and the link between the last real address and the
/// public "last address" bits.
fn eval(&self, builder: &mut AB) {
// Borrow the current and the next row as typed column structs.
let main = builder.main();
let local = main.row_slice(0);
let local: &MemoryInitCols<AB::Var> = (*local).borrow();
let next = main.row_slice(1);
let next: &MemoryInitCols<AB::Var> = (*next).borrow();
// `is_real` and every value bit must be boolean.
builder.assert_bool(local.is_real);
for i in 0..32 {
builder.assert_bool(local.value[i]);
}
// Recompose the 32 value bits into four byte expressions; within each byte, bit `i`
// carries weight `1 << i`.
let mut byte1 = AB::Expr::zero();
let mut byte2 = AB::Expr::zero();
let mut byte3 = AB::Expr::zero();
let mut byte4 = AB::Expr::zero();
for i in 0..8 {
byte1 += local.value[i].into() * AB::F::from_canonical_u8(1 << i);
byte2 += local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i);
byte3 += local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i);
byte4 += local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i);
}
let value = [byte1, byte2, byte3, byte4];
// Memory interaction. The initialize chip receives `[0, 0, addr, bytes..]`, the finalize
// chip sends `[shard, timestamp, addr, bytes..]`; `is_real` is passed as the second
// argument of the interaction (presumably its multiplicity).
if self.kind == MemoryChipType::Initialize {
let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), local.addr.into()];
values.extend(value.map(Into::into));
builder.receive(AirInteraction::new(
values,
local.is_real.into(),
crate::lookup::InteractionKind::Memory,
));
} else {
let mut values = vec![
local.shard.into(),
local.timestamp.into(),
local.addr.into(),
];
values.extend(value);
builder.send(AirInteraction::new(
values,
local.is_real.into(),
crate::lookup::InteractionKind::Memory,
));
}
// Tie `addr` to its bit decomposition `addr_bits`; `is_real` is passed as the selector.
BabyBearBitDecomposition::<AB::F>::range_check(
builder,
local.addr,
local.addr_bits,
local.is_real.into(),
);
// On transition rows the next row's comparison flag must equal its `is_real` flag.
builder
.when_transition()
.assert_eq(next.is_next_comp, next.is_real);
// Compare this row's address bits with the next row's, gated by `next.is_next_comp`
// (per the `AssertLtColsBits` name this presumably enforces local addr < next addr).
next.lt_cols.eval(
builder,
&local.addr_bits.bits,
&next.addr_bits.bits,
next.is_next_comp,
);
// Once a row is not real, the following row must not be real either.
builder
.when_transition()
.when_not(local.is_real)
.assert_zero(next.is_real);
let local_addr_bits = local.addr_bits.bits;
// View the public values as a typed struct and select the "previous address" bits that
// belong to this chip kind.
let public_values_array: [AB::Expr; SP1_PROOF_NUM_PV_ELTS] =
array::from_fn(|i| builder.public_values()[i].into());
let public_values: &PublicValues<Word<AB::Expr>, AB::Expr> =
public_values_array.as_slice().borrow();
let prev_addr_bits = match self.kind {
MemoryChipType::Initialize => &public_values.previous_init_addr_bits,
MemoryChipType::Finalize => &public_values.previous_finalize_addr_bits,
};
// Recompose the previous address from its bits (bit `i` weighted by `1 << i`).
let prev_addr = prev_addr_bits
.iter()
.enumerate()
.map(|(i, bit)| bit.clone() * AB::F::from_wrapped_u32(1 << i))
.sum::<AB::Expr>();
// Evaluate the is-zero operation on the previous address, selected by the first-row flag.
let is_first_row = builder.is_first_row();
IsZeroOperation::<AB::F>::eval(builder, prev_addr, local.is_prev_addr_zero, is_first_row);
// `is_first_comp` is boolean on every row and, on the first row, equals
// `1 - is_prev_addr_zero.result` (i.e. it is set when the previous address is non-zero).
builder.assert_bool(local.is_first_comp);
builder.when_first_row().assert_eq(
local.is_first_comp,
AB::Expr::one() - local.is_prev_addr_zero.result,
);
// The first row must be real, so the trace holds at least one real row.
builder.when_first_row().assert_one(local.is_real);
// Compare the public previous address bits with this row's address bits, gated by
// `is_first_comp`.
local.lt_cols.eval(
builder,
prev_addr_bits,
&local_addr_bits,
local.is_first_comp,
);
// When the previous address is zero, the first row must have address zero, the second
// row must be real, and the second row's comparison must be active.
builder
.when_first_row()
.when(local.is_prev_addr_zero.result)
.assert_zero(local.addr);
builder
.when_first_row()
.when(local.is_prev_addr_zero.result)
.assert_one(next.is_real);
builder
.when_first_row()
.when(local.is_prev_addr_zero.result)
.assert_one(next.is_next_comp);
// Initialize chip only: every real row carries timestamp one.
if self.kind == MemoryChipType::Initialize {
builder
.when(local.is_real)
.assert_eq(local.timestamp, AB::F::one());
}
// On the first row, if no comparison against the previous address is made
// (`is_first_comp == 0`), all value bits must be zero.
for i in 0..32 {
builder
.when_first_row()
.when_not(local.is_first_comp)
.assert_zero(local.value[i]);
}
// Select the public "last address" bits that belong to this chip kind.
let last_addr_bits = match self.kind {
MemoryChipType::Initialize => &public_values.last_init_addr_bits,
MemoryChipType::Finalize => &public_values.last_finalize_addr_bits,
};
// On transition rows, `is_last_addr` is set exactly when this row is real and the next
// row is not.
builder.when_transition().assert_eq(
local.is_last_addr,
local.is_real * (AB::Expr::one() - next.is_real),
);
// The address bits must match the public last-address bits either on the last row when
// that row is real, or on a transition row flagged with `is_last_addr`.
for (local_bit, pub_bit) in local.addr_bits.bits.iter().zip(last_addr_bits.iter()) {
builder
.when_last_row()
.when(local.is_real)
.assert_eq(*local_bit, pub_bit.clone());
builder
.when_transition()
.when(local.is_last_addr)
.assert_eq(*local_bit, pub_bit.clone());
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::lookup::{debug_interactions_with_all_chips, InteractionKind};
    use crate::runtime::tests::simple_program;
    use crate::runtime::Runtime;
    use crate::stark::RiscvAir;
    use crate::syscall::precompiles::sha256::extend_tests::sha_extend_program;
    use crate::utils::{setup_logger, BabyBearPoseidon2, SP1CoreOpts};
    use p3_baby_bear::BabyBear;

    /// Runs the SHA-extend program, generates dependencies for its two shards and feeds
    /// them to the interaction debugger for the given interaction kinds.
    fn debug_sha_extend_interactions(kinds: Vec<InteractionKind>) {
        setup_logger();
        let program = sha_extend_program();
        let mut runtime = Runtime::new(program.clone(), SP1CoreOpts::default());
        runtime.run().unwrap();

        let machine: crate::stark::StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(&mut runtime.records, &opts);

        let shards = runtime.records;
        assert_eq!(shards.len(), 2);
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &shards,
            kinds,
        );
    }

    /// Generates and prints the trace of both chip kinds for the simple program, followed
    /// by the finalize events.
    #[test]
    fn test_memory_generate_trace() {
        let mut runtime = Runtime::new(simple_program(), SP1CoreOpts::default());
        runtime.run().unwrap();
        let shard = runtime.record.clone();

        for kind in [MemoryChipType::Initialize, MemoryChipType::Finalize] {
            let chip: MemoryChip = MemoryChip::new(kind);
            let trace: RowMajorMatrix<BabyBear> =
                chip.generate_trace(&shard, &mut ExecutionRecord::default());
            println!("{:?}", trace.values);
        }

        for mem_event in shard.memory_finalize_events {
            println!("{:?}", mem_event);
        }
    }

    /// Debugs the memory interactions across all chips.
    #[test]
    fn test_memory_lookup_interactions() {
        debug_sha_extend_interactions(vec![InteractionKind::Memory]);
    }

    /// Debugs the byte interactions across all chips.
    #[test]
    fn test_byte_lookup_interactions() {
        debug_sha_extend_interactions(vec![InteractionKind::Byte]);
    }
}