use std::{
borrow::{Borrow, BorrowMut},
mem::size_of,
};
use crate::utils::pad_rows_fixed;
use itertools::Itertools;
use p3_air::{Air, BaseAir};
use p3_field::PrimeField32;
use p3_matrix::{dense::RowMajorMatrix, Matrix};
use sp1_core_executor::{ExecutionRecord, Program};
use sp1_derive::AlignedBorrow;
use sp1_stark::{
air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
InteractionKind, Word,
};
/// Number of `SingleMemoryLocal` entries stored side by side in one trace row.
pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 4;
/// Width of the chip's trace, computed as the byte size of `MemoryLocalCols<u8>`.
// NOTE(review): this equals the column count only if every cell (including the
// cells inside `Word<u8>`) occupies exactly one byte with no padding — confirm.
pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::<MemoryLocalCols<u8>>();
/// Columns for a single local memory event: the address together with its
/// initial and final memory access records.
///
/// The struct is `#[repr(C)]` and borrowed from a flat row slice, so the field
/// order below defines the column order in the trace.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
struct SingleMemoryLocal<T> {
/// The memory address of the event.
pub addr: T,
/// Shard of the event's initial memory access.
pub initial_shard: T,
/// Shard of the event's final memory access.
pub final_shard: T,
/// Timestamp of the event's initial memory access.
pub initial_clk: T,
/// Timestamp of the event's final memory access.
pub final_clk: T,
/// Value recorded by the initial memory access.
pub initial_value: Word<T>,
/// Value recorded by the final memory access.
pub final_value: Word<T>,
/// Set to one when this slot holds an event; left at zero for unused slots.
/// Used as the multiplicity of the slot's memory interactions.
pub is_real: T,
}
/// One full trace row of the `MemoryLocalChip`: a fixed-size array of
/// `NUM_LOCAL_MEMORY_ENTRIES_PER_ROW` local memory entries.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct MemoryLocalCols<T> {
// Entries are laid out back to back; slots without an event stay all-zero.
memory_local_entries: [SingleMemoryLocal<T>; NUM_LOCAL_MEMORY_ENTRIES_PER_ROW],
}
/// Stateless chip whose trace holds the local memory events of a shard.
///
/// `Default` is derived so the type can be built wherever a `Default` bound is
/// expected; it is equivalent to [`MemoryLocalChip::new`].
#[derive(Default)]
pub struct MemoryLocalChip {}

impl MemoryLocalChip {
    /// Creates a new chip. The chip carries no state, so this is usable in
    /// `const` contexts.
    pub const fn new() -> Self {
        Self {}
    }
}
impl<F> BaseAir<F> for MemoryLocalChip {
    /// Number of columns in the chip's trace.
    fn width(&self) -> usize {
        NUM_MEMORY_LOCAL_INIT_COLS
    }
}
impl<F: PrimeField32> MachineAir<F> for MemoryLocalChip {
    type Record = ExecutionRecord;
    type Program = Program;

    /// Returns the chip's name, `"MemoryLocal"`.
    fn name(&self) -> String {
        "MemoryLocal".to_string()
    }

    /// Intentionally a no-op: this chip writes nothing to the output record.
    fn generate_dependencies(&self, _input: &ExecutionRecord, _output: &mut ExecutionRecord) {}

    /// Builds the trace from the record's local memory events.
    ///
    /// Events are packed `NUM_LOCAL_MEMORY_ENTRIES_PER_ROW` per row. Slots of a
    /// partially filled last row are left all-zero, so their `is_real` is zero.
    /// The rows are then handed to `pad_rows_fixed` together with an all-zero
    /// row generator and `input.fixed_log2_rows(self)`.
    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        let mut rows = Vec::<[F; NUM_MEMORY_LOCAL_INIT_COLS]>::new();

        for local_mem_events in
            &input.get_local_mem_events().chunks(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
        {
            let mut row = [F::zero(); NUM_MEMORY_LOCAL_INIT_COLS];
            let cols: &mut MemoryLocalCols<F> = row.as_mut_slice().borrow_mut();

            // `zip` stops at the shorter side, so a short final chunk simply
            // leaves the remaining slots untouched (zero).
            for (entry, event) in cols.memory_local_entries.iter_mut().zip(local_mem_events) {
                entry.addr = F::from_canonical_u32(event.addr);
                entry.initial_shard = F::from_canonical_u32(event.initial_mem_access.shard);
                entry.final_shard = F::from_canonical_u32(event.final_mem_access.shard);
                entry.initial_clk = F::from_canonical_u32(event.initial_mem_access.timestamp);
                entry.final_clk = F::from_canonical_u32(event.final_mem_access.timestamp);
                entry.initial_value = event.initial_mem_access.value.into();
                entry.final_value = event.final_mem_access.value.into();
                entry.is_real = F::one();
            }

            rows.push(row);
        }

        // NOTE(review): presumably pads up to the fixed shape height (or a
        // power of two when no shape is set) — confirm in `pad_rows_fixed`.
        pad_rows_fixed(
            &mut rows,
            || [F::zero(); NUM_MEMORY_LOCAL_INIT_COLS],
            input.fixed_log2_rows::<F, _>(self),
        );

        RowMajorMatrix::new(
            rows.into_iter().flatten().collect::<Vec<_>>(),
            NUM_MEMORY_LOCAL_INIT_COLS,
        )
    }

    /// Decides whether this chip takes part in proving `shard`.
    ///
    /// When the record carries a shape, the shape decides; otherwise the chip
    /// is included exactly when the record has at least one local memory event.
    fn included(&self, shard: &Self::Record) -> bool {
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            // Only the existence of a first event matters.
            shard.get_local_mem_events().next().is_some()
        }
    }

    /// The chip's trace is committed in the global interaction scope.
    fn commit_scope(&self) -> InteractionScope {
        InteractionScope::Global
    }
}
impl<AB> Air<AB> for MemoryLocalChip
where
    AB: SP1AirBuilder,
{
    /// Emits the constraints and memory interactions for one trace row.
    ///
    /// For every entry slot of the row, and for both the global and the local
    /// interaction scope (in that order), the chip:
    /// * receives `(initial_shard, initial_clk, addr, initial_value...)`, and
    /// * sends `(final_shard, final_clk, addr, final_value...)`,
    ///
    /// both as `InteractionKind::Memory` with multiplicity `is_real`.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &MemoryLocalCols<AB::Var> = (*local).borrow();

        for entry in local.memory_local_entries.iter() {
            // Both sides are the same expression, so this holds for every value
            // of `is_real`; it does not constrain `is_real` to be boolean.
            // NOTE(review): presumably present only to pin the chip's maximum
            // constraint degree at 3 — confirm before touching it.
            builder.assert_eq(
                entry.is_real * entry.is_real * entry.is_real,
                entry.is_real * entry.is_real * entry.is_real,
            );

            for scope in [InteractionScope::Global, InteractionScope::Local] {
                // Initial access: shard, clk, addr followed by the value word.
                let mut values =
                    vec![entry.initial_shard.into(), entry.initial_clk.into(), entry.addr.into()];
                values.extend(entry.initial_value.map(Into::into));
                builder.receive(
                    AirInteraction::new(values, entry.is_real.into(), InteractionKind::Memory),
                    scope,
                );

                // Final access: same layout, using the final shard/clk/value.
                let mut values =
                    vec![entry.final_shard.into(), entry.final_clk.into(), entry.addr.into()];
                values.extend(entry.final_value.map(Into::into));
                builder.send(
                    AirInteraction::new(values, entry.is_real.into(), InteractionKind::Memory),
                    scope,
                );
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{programs::tests::simple_program, ExecutionRecord, Executor};
    use sp1_stark::{
        air::{InteractionScope, MachineAir},
        baby_bear_poseidon2::BabyBearPoseidon2,
        debug_interactions_with_all_chips, InteractionKind, SP1CoreOpts, StarkMachine,
    };

    use crate::{
        memory::MemoryLocalChip, riscv::RiscvAir,
        syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
    };

    /// Smoke test: runs the simple program, generates the chip's trace for the
    /// first record and prints it. Nothing is asserted beyond "does not panic".
    #[test]
    fn test_local_memory_generate_trace() {
        let mut executor = Executor::new(simple_program(), SP1CoreOpts::default());
        executor.run().unwrap();
        let first_record = executor.records[0].clone();

        let chip: MemoryLocalChip = MemoryLocalChip::new();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&first_record, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        // Also dump the record's global memory finalize events for inspection.
        for mem_event in first_record.global_memory_finalize_events {
            println!("{:?}", mem_event);
        }
    }

    /// Runs the interaction debugger for `Memory` interactions: once per
    /// record in the local scope, then once over all records in the global
    /// scope.
    #[test]
    fn test_memory_lookup_interactions() {
        setup_logger();

        let program = sha_extend_program();
        let mut executor = Executor::new(program.clone(), SP1CoreOpts::default());
        executor.run().unwrap();

        let machine: StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(&mut executor.records, &opts, None);

        let records = executor.records;
        for record in records.iter().cloned() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[record],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }

        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &records,
            vec![InteractionKind::Memory],
            InteractionScope::Global,
        );
    }

    /// Runs the interaction debugger per record in the local scope, then for
    /// `Byte` interactions over all records in the global scope.
    #[test]
    fn test_byte_lookup_interactions() {
        setup_logger();

        let program = sha_extend_program();
        let mut executor = Executor::new(program.clone(), SP1CoreOpts::default());
        executor.run().unwrap();

        let machine = RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(&mut executor.records, &opts, None);

        let records = executor.records;
        // NOTE(review): the per-record pass checks `Memory`, not `Byte`,
        // interactions — possibly a copy-paste from the test above; confirm.
        for record in records.iter().cloned() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[record],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }

        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &records,
            vec![InteractionKind::Byte],
            InteractionScope::Global,
        );
    }
}