1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5
6use crate::utils::{next_power_of_two, zeroed_f_vec};
7
8use p3_air::{Air, BaseAir};
9use p3_field::{AbstractField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use p3_maybe_rayon::prelude::{
12 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
13};
14use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord, Program};
15use sp1_derive::AlignedBorrow;
16use sp1_stark::{
17 air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
18 InteractionKind, Word,
19};
20
/// Number of [`SingleMemoryLocal`] entries packed into one trace row.
pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 4;
/// Width of one trace row, taken as the byte size of [`MemoryLocalCols`] instantiated with `u8`.
pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::<MemoryLocalCols<u8>>();
24
/// Columns describing one local memory entry: its address plus the initial and final
/// (shard, clk, value) access records.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct SingleMemoryLocal<T: Copy> {
    /// The address of the memory entry.
    pub addr: T,

    /// The shard of the initial memory access.
    pub initial_shard: T,

    /// The shard of the final memory access.
    pub final_shard: T,

    /// The timestamp of the initial memory access.
    pub initial_clk: T,

    /// The timestamp of the final memory access.
    pub final_clk: T,

    /// The value of the initial memory access.
    pub initial_value: Word<T>,

    /// The value of the final memory access.
    pub final_value: Word<T>,

    /// One when the entry was filled from an event, zero for padding. Constrained to be boolean
    /// in `eval` and used as the multiplicity of every interaction of the entry.
    pub is_real: T,
}
54
/// One trace row of the chip: `NUM_LOCAL_MEMORY_ENTRIES_PER_ROW` local memory entries side by side.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct MemoryLocalCols<T: Copy> {
    /// The entries of this row; entries not backed by an event keep `is_real == 0`.
    memory_local_entries: [SingleMemoryLocal<T>; NUM_LOCAL_MEMORY_ENTRIES_PER_ROW],
}
60
/// Chip that processes the local memory events of an execution record.
///
/// The chip carries no state; all data comes from the record passed to the trace methods.
pub struct MemoryLocalChip {}

impl MemoryLocalChip {
    /// Creates a new local memory chip.
    pub const fn new() -> Self {
        MemoryLocalChip {}
    }
}
69
impl<F> BaseAir<F> for MemoryLocalChip {
    /// Returns the number of columns in a trace row, `NUM_MEMORY_LOCAL_INIT_COLS`.
    fn width(&self) -> usize {
        NUM_MEMORY_LOCAL_INIT_COLS
    }
}
75
76fn nb_rows(count: usize) -> usize {
77 if NUM_LOCAL_MEMORY_ENTRIES_PER_ROW > 1 {
78 count.div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
79 } else {
80 count
81 }
82}
83
84impl<F: PrimeField32> MachineAir<F> for MemoryLocalChip {
85 type Record = ExecutionRecord;
86
87 type Program = Program;
88
89 fn name(&self) -> String {
90 "MemoryLocal".to_string()
91 }
92
93 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
94 let mut events = Vec::new();
95
96 input.get_local_mem_events().for_each(|mem_event| {
97 events.push(GlobalInteractionEvent {
98 message: [
99 mem_event.initial_mem_access.shard,
100 mem_event.initial_mem_access.timestamp,
101 mem_event.addr,
102 mem_event.initial_mem_access.value & 255,
103 (mem_event.initial_mem_access.value >> 8) & 255,
104 (mem_event.initial_mem_access.value >> 16) & 255,
105 (mem_event.initial_mem_access.value >> 24) & 255,
106 ],
107 is_receive: true,
108 kind: InteractionKind::Memory as u8,
109 });
110 events.push(GlobalInteractionEvent {
111 message: [
112 mem_event.final_mem_access.shard,
113 mem_event.final_mem_access.timestamp,
114 mem_event.addr,
115 mem_event.final_mem_access.value & 255,
116 (mem_event.final_mem_access.value >> 8) & 255,
117 (mem_event.final_mem_access.value >> 16) & 255,
118 (mem_event.final_mem_access.value >> 24) & 255,
119 ],
120 is_receive: false,
121 kind: InteractionKind::Memory as u8,
122 });
123 });
124
125 output.global_interaction_events.extend(events);
126 }
127
128 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
129 let count = input.get_local_mem_events().count();
130 let nb_rows = nb_rows(count);
131 let size_log2 = input.fixed_log2_rows::<F, _>(self);
132 Some(next_power_of_two(nb_rows, size_log2))
133 }
134
135 fn generate_trace(
136 &self,
137 input: &Self::Record,
138 _output: &mut Self::Record,
139 ) -> RowMajorMatrix<F> {
140 let events = input.get_local_mem_events().collect::<Vec<_>>();
142 let nb_rows = nb_rows(events.len());
143 let padded_nb_rows = <MemoryLocalChip as MachineAir<F>>::num_rows(self, input).unwrap();
144 let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
145 let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
146
147 let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
148 .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
149 .collect::<Vec<_>>();
150
151 chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
152 rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
153 let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
154
155 let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
156 for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
157 let cols = &mut cols.memory_local_entries[k];
158 if idx + k < events.len() {
159 let event = &events[idx + k];
160 cols.addr = F::from_canonical_u32(event.addr);
161 cols.initial_shard = F::from_canonical_u32(event.initial_mem_access.shard);
162 cols.final_shard = F::from_canonical_u32(event.final_mem_access.shard);
163 cols.initial_clk =
164 F::from_canonical_u32(event.initial_mem_access.timestamp);
165 cols.final_clk = F::from_canonical_u32(event.final_mem_access.timestamp);
166 cols.initial_value = event.initial_mem_access.value.into();
167 cols.final_value = event.final_mem_access.value.into();
168 cols.is_real = F::one();
169 }
170 }
171 });
172 });
173
174 RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
176 }
177
178 fn included(&self, shard: &Self::Record) -> bool {
179 if let Some(shape) = shard.shape.as_ref() {
180 shape.included::<F, _>(self)
181 } else {
182 shard.get_local_mem_events().nth(0).is_some()
183 }
184 }
185
186 fn commit_scope(&self) -> InteractionScope {
187 InteractionScope::Local
188 }
189}
190
impl<AB> Air<AB> for MemoryLocalChip
where
    AB: SP1AirBuilder,
{
    /// Emits the constraints and interactions for every entry of the current row.
    ///
    /// For each entry, with `is_real` as the multiplicity of every interaction:
    /// - `is_real` is constrained to be boolean;
    /// - the initial (shard, clk, addr, value) tuple is received as a local `Memory` interaction
    ///   and sent as a `Global` interaction;
    /// - the final (shard, clk, addr, value) tuple is sent as a `Global` interaction and as a
    ///   local `Memory` interaction.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &MemoryLocalCols<AB::Var> = (*local).borrow();

        for local in local.memory_local_entries.iter() {
            // Constrain that `local.is_real` is boolean.
            builder.assert_bool(local.is_real);

            // Both sides are the same expression `is_real^3`, so this adds no restriction on the
            // witness. NOTE(review): presumably kept to pin the chip's constraint degree at 3 —
            // confirm before removing.
            builder.assert_eq(
                local.is_real * local.is_real * local.is_real,
                local.is_real * local.is_real * local.is_real,
            );

            // Receive the initial access [shard, clk, addr, value...] as a local memory
            // interaction.
            let mut values =
                vec![local.initial_shard.into(), local.initial_clk.into(), local.addr.into()];
            values.extend(local.initial_value.map(Into::into));
            builder.receive(
                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
                InteractionScope::Local,
            );

            // Send the initial access as a `Global` interaction. The message is followed by the
            // constants 0 and 1 and the `Memory` kind, lining up with the `is_receive: true` event
            // emitted in `generate_dependencies`. NOTE(review): presumably the two constants are
            // the (is_send, is_receive) flags — confirm against the global chip.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.initial_shard.into(),
                        local.initial_clk.into(),
                        local.addr.into(),
                        local.initial_value[0].into(),
                        local.initial_value[1].into(),
                        local.initial_value[2].into(),
                        local.initial_value[3].into(),
                        AB::Expr::zero(),
                        AB::Expr::one(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );

            // Send the final access as a `Global` interaction. Here the constants are 1 and 0,
            // lining up with the `is_receive: false` event emitted in `generate_dependencies`.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.final_shard.into(),
                        local.final_clk.into(),
                        local.addr.into(),
                        local.final_value[0].into(),
                        local.final_value[1].into(),
                        local.final_value[2].into(),
                        local.final_value[3].into(),
                        AB::Expr::one(),
                        AB::Expr::zero(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );

            // Send the final access [shard, clk, addr, value...] as a local memory interaction.
            let mut values =
                vec![local.final_shard.into(), local.final_clk.into(), local.addr.into()];
            values.extend(local.final_value.map(Into::into));
            builder.send(
                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
                InteractionScope::Local,
            );
        }
    }
}
269
#[cfg(test)]
mod tests {
    // The tests print traces and events for manual inspection.
    #![allow(clippy::print_stdout)]

    use crate::{
        memory::MemoryLocalChip, programs::tests::*, riscv::RiscvAir,
        syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
    };
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Executor};
    use sp1_stark::{
        air::{InteractionScope, MachineAir},
        baby_bear_poseidon2::BabyBearPoseidon2,
        debug_interactions_with_all_chips, InteractionKind, SP1CoreOpts, StarkMachine,
    };

    /// Smoke test: generates the local memory trace for the first shard of `simple_program` and
    /// prints it together with the shard's global memory finalize events.
    #[test]
    fn test_local_memory_generate_trace() {
        let program = simple_program();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let shard = runtime.records[0].clone();

        let chip: MemoryLocalChip = MemoryLocalChip::new();

        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        for mem_event in shard.global_memory_finalize_events {
            println!("{mem_event:?}");
        }
    }

    /// Runs the sha-extend program and debugs the `Memory` interactions, per shard in the local
    /// scope and across all shards in the global scope.
    #[test]
    fn test_memory_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let program_clone = program.clone();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let machine: StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program_clone);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(
            &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
            &opts,
            None,
        );

        let shards = runtime.records;
        for shard in shards.clone() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[*shard],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
            vec![InteractionKind::Memory],
            InteractionScope::Global,
        );
    }

    /// Runs the sha-extend program and debugs the `Byte` interactions across all shards in the
    /// global scope, after a per-shard local pass.
    #[test]
    fn test_byte_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let program_clone = program.clone();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let machine = RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program_clone);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(
            &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
            &opts,
            None,
        );

        let shards = runtime.records;
        // NOTE(review): this per-shard pass checks `InteractionKind::Memory`, exactly like
        // `test_memory_lookup_interactions`; presumably `InteractionKind::Byte` was intended —
        // confirm before changing.
        for shard in shards.clone() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[*shard],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
            vec![InteractionKind::Byte],
            InteractionScope::Global,
        );
    }

    /// Builds an execution record holding 256 random local memory events.
    #[cfg(feature = "sys")]
    fn get_test_execution_record() -> ExecutionRecord {
        use p3_field::PrimeField32;
        use rand::{thread_rng, Rng};
        use sp1_core_executor::events::{MemoryLocalEvent, MemoryRecord};

        let mut rng = thread_rng();
        let cpu_local_memory_access = (0..=255)
            .map(|_| {
                let addr = rng.gen_range(0..BabyBear::ORDER_U32);
                let init_value = rng.gen_range(0..u32::MAX);
                let init_shard = rng.gen_range(0..(1u32 << 16));
                let init_timestamp = rng.gen_range(0..(1u32 << 24));
                let final_value = rng.gen_range(0..u32::MAX);
                let final_timestamp = rng.gen_range(0..(1u32 << 24));
                let final_shard = rng.gen_range(0..(1u32 << 16));
                MemoryLocalEvent {
                    addr,
                    initial_mem_access: MemoryRecord {
                        shard: init_shard,
                        timestamp: init_timestamp,
                        value: init_value,
                    },
                    final_mem_access: MemoryRecord {
                        shard: final_shard,
                        timestamp: final_timestamp,
                        value: final_value,
                    },
                }
            })
            .collect::<Vec<_>>();
        ExecutionRecord { cpu_local_memory_access, ..Default::default() }
    }

    /// Checks that the FFI row filler produces the same trace as the Rust implementation.
    #[cfg(feature = "sys")]
    #[test]
    fn test_generate_trace_ffi_eq_rust() {
        use p3_matrix::Matrix;

        let record = get_test_execution_record();
        let chip = MemoryLocalChip::new();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&record, &mut ExecutionRecord::default());
        let trace_ffi = generate_trace_ffi(&record, trace.height());

        assert_eq!(trace_ffi, trace);
    }

    /// Mirrors `MemoryLocalChip::generate_trace`, but fills every entry through
    /// `crate::sys::memory_local_event_to_row_babybear`. `height` is the padded number of rows.
    #[cfg(feature = "sys")]
    fn generate_trace_ffi(input: &ExecutionRecord, height: usize) -> RowMajorMatrix<BabyBear> {
        use std::borrow::BorrowMut;

        use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

        use crate::{
            memory::{
                MemoryLocalCols, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, NUM_MEMORY_LOCAL_INIT_COLS,
            },
            utils::zeroed_f_vec,
        };

        type F = BabyBear;
        let events = input.get_local_mem_events().collect::<Vec<_>>();
        // Use the shared constant rather than a literal so this stays in sync with the chip.
        let nb_rows = events.len().div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW);
        let padded_nb_rows = height;
        let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
        // At least one row per chunk, so `chunks_mut` never receives a zero chunk size.
        let chunk_size = nb_rows / num_cpus::get() + 1;

        let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
            .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
            .collect::<Vec<_>>();

        chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
            rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
                let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
                let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
                for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
                    let cols = &mut cols.memory_local_entries[k];
                    if idx + k < events.len() {
                        // SAFETY: both arguments are live references for the whole call, and
                        // `cols` is an exclusive borrow of one entry. NOTE(review): assumes the
                        // FFI routine only writes within that single entry — confirm against the
                        // `crate::sys` binding.
                        unsafe {
                            crate::sys::memory_local_event_to_row_babybear(events[idx + k], cols);
                        }
                    }
                }
            });
        });

        RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
    }
}