sp1_core_machine/memory/local.rs
1use std::{
2 borrow::{Borrow, BorrowMut},
3 mem::{size_of, MaybeUninit},
4};
5
6use crate::{air::WordAirBuilder, utils::next_multiple_of_32};
7use slop_air::{Air, BaseAir};
8use slop_algebra::{AbstractField, PrimeField32};
9use slop_matrix::Matrix;
10use slop_maybe_rayon::prelude::{
11 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
12};
13use sp1_core_executor::{
14 events::{ByteRecord, GlobalInteractionEvent},
15 ExecutionRecord, Program,
16};
17use sp1_derive::AlignedBorrow;
18use sp1_hypercube::{
19 air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
20 InteractionKind, Word,
21};
22use struct_reflection::{StructReflection, StructReflectionHelper};
23
/// Number of `SingleMemoryLocal` entries packed into one `MemoryLocalCols` trace row.
pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 1;
/// Size in bytes of `MemoryLocalCols<u8>`; `MemoryLocalChip` reports this value as its trace
/// width, i.e. one column per `u8`-sized field.
pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::<MemoryLocalCols<u8>>();
26
/// Columns describing one local memory event: the address, the initial and final timestamps
/// (each split into a high and a low part), the initial and final values, and the byte
/// decomposition of the third limb of each value.
///
/// The struct is `#[repr(C)]`, so the field order below is the column order.
#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
#[repr(C)]
pub struct SingleMemoryLocal<T: Copy> {
    /// The address of the memory access, as three 16-bit limbs (least significant first).
    pub addr: [T; 3],

    /// The high bits of the initial clk of the memory access (`timestamp >> 24`).
    pub initial_clk_high: T,

    /// The high bits of the final clk of the memory access (`timestamp >> 24`).
    pub final_clk_high: T,

    /// The low 24 bits of the initial clk of the memory access (`timestamp & 0xFFFFFF`).
    pub initial_clk_low: T,

    /// The low 24 bits of the final clk of the memory access (`timestamp & 0xFFFFFF`).
    pub final_clk_low: T,

    /// The initial value of the memory access.
    pub initial_value: Word<T>,

    /// The final value of the memory access.
    pub final_value: Word<T>,

    /// Lower byte of the third limb of the initial value (bits 32..40 of the value).
    pub initial_value_lower: T,

    /// Upper byte of the third limb of the initial value (bits 40..48 of the value).
    pub initial_value_upper: T,

    /// Lower byte of the third limb of the final value (bits 32..40 of the value).
    pub final_value_lower: T,

    /// Upper byte of the third limb of the final value (bits 40..48 of the value).
    pub final_value_upper: T,

    /// Whether the memory access is a real access; set to one for rows backed by an event.
    pub is_real: T,
}
66
/// One trace row of the local memory chip: `NUM_LOCAL_MEMORY_ENTRIES_PER_ROW` consecutive
/// `SingleMemoryLocal` entries.
#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
#[repr(C)]
pub struct MemoryLocalCols<T: Copy> {
    // The per-event column groups that make up the row.
    memory_local_entries: [SingleMemoryLocal<T>; NUM_LOCAL_MEMORY_ENTRIES_PER_ROW],
}
72
/// Chip that builds its trace and interactions from the local memory events of an execution
/// record.
///
/// The chip is stateless, so `MemoryLocalChip::new()` and `MemoryLocalChip::default()` are
/// interchangeable.
#[derive(Default)]
pub struct MemoryLocalChip {}

impl MemoryLocalChip {
    /// Creates a new local memory chip.
    pub const fn new() -> Self {
        Self {}
    }
}
81
impl<F> BaseAir<F> for MemoryLocalChip {
    /// Number of columns in a trace row, i.e. the size of `MemoryLocalCols<u8>`.
    fn width(&self) -> usize {
        NUM_MEMORY_LOCAL_INIT_COLS
    }
}
87
88fn nb_rows(count: usize) -> usize {
89 if NUM_LOCAL_MEMORY_ENTRIES_PER_ROW > 1 {
90 count.div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
91 } else {
92 count
93 }
94}
95
96impl<F: PrimeField32> MachineAir<F> for MemoryLocalChip {
97 type Record = ExecutionRecord;
98
99 type Program = Program;
100
101 fn name(&self) -> &'static str {
102 "MemoryLocal"
103 }
104
105 fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
106 let mut events = Vec::new();
107
108 input.get_local_mem_events().for_each(|mem_event| {
109 let mut blu = Vec::with_capacity(10); // 1 + 4 + 1 + 4
110 let initial_value_byte0 = ((mem_event.initial_mem_access.value >> 32) & 0xFF) as u32;
111 let initial_value_byte1 = ((mem_event.initial_mem_access.value >> 40) & 0xFF) as u32;
112 blu.add_u8_range_check(initial_value_byte0 as u8, initial_value_byte1 as u8);
113 blu.add_u16_range_checks_field::<F>(&Word::from(mem_event.initial_mem_access.value).0);
114
115 events.push(GlobalInteractionEvent {
116 message: [
117 (mem_event.initial_mem_access.timestamp >> 24) as u32,
118 (mem_event.initial_mem_access.timestamp & 0xFFFFFF) as u32,
119 (mem_event.addr & 0xFFFF) as u32,
120 ((mem_event.addr >> 16) & 0xFFFF) as u32,
121 ((mem_event.addr >> 32) & 0xFFFF) as u32,
122 (mem_event.initial_mem_access.value & 0xFFFF) as u32
123 + (1 << 16) * initial_value_byte0,
124 ((mem_event.initial_mem_access.value >> 16) & 0xFFFF) as u32
125 + (1 << 16) * initial_value_byte1,
126 ((mem_event.initial_mem_access.value >> 48) & 0xFFFF) as u32,
127 ],
128 is_receive: true,
129 kind: InteractionKind::Memory as u8,
130 });
131
132 let final_value_byte0 = ((mem_event.final_mem_access.value >> 32) & 0xFF) as u32;
133 let final_value_byte1 = ((mem_event.final_mem_access.value >> 40) & 0xFF) as u32;
134 blu.add_u8_range_check(final_value_byte0 as u8, final_value_byte1 as u8);
135 blu.add_u16_range_checks_field::<F>(&Word::from(mem_event.final_mem_access.value).0);
136 events.push(GlobalInteractionEvent {
137 message: [
138 (mem_event.final_mem_access.timestamp >> 24) as u32,
139 (mem_event.final_mem_access.timestamp & 0xFFFFFF) as u32,
140 (mem_event.addr & 0xFFFF) as u32,
141 ((mem_event.addr >> 16) & 0xFFFF) as u32,
142 ((mem_event.addr >> 32) & 0xFFFF) as u32,
143 (mem_event.final_mem_access.value & 0xFFFF) as u32
144 + (1 << 16) * final_value_byte0,
145 ((mem_event.final_mem_access.value >> 16) & 0xFFFF) as u32
146 + (1 << 16) * final_value_byte1,
147 ((mem_event.final_mem_access.value >> 48) & 0xFFFF) as u32,
148 ],
149 is_receive: false,
150 kind: InteractionKind::Memory as u8,
151 });
152
153 output.add_byte_lookup_events(blu);
154 });
155
156 output.global_interaction_events.extend(events);
157 }
158
159 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
160 let count = input.get_local_mem_events().count();
161 let nb_rows = nb_rows(count);
162 let size_log2 = input.fixed_log2_rows::<F, _>(self);
163 Some(next_multiple_of_32(nb_rows, size_log2))
164 }
165
166 fn generate_trace_into(
167 &self,
168 input: &ExecutionRecord,
169 _output: &mut ExecutionRecord,
170 buffer: &mut [MaybeUninit<F>],
171 ) {
172 // Generate the trace rows for each event.
173 let events = input.get_local_mem_events().collect::<Vec<_>>();
174 let nb_rows = nb_rows(events.len());
175 let padded_nb_rows = <MemoryLocalChip as MachineAir<F>>::num_rows(self, input).unwrap();
176 let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
177
178 unsafe {
179 let padding_start = nb_rows * NUM_MEMORY_LOCAL_INIT_COLS;
180 let padding_size = (padded_nb_rows - nb_rows) * NUM_MEMORY_LOCAL_INIT_COLS;
181 if padding_size > 0 {
182 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
183 }
184 }
185
186 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
187 let values = unsafe {
188 core::slice::from_raw_parts_mut(buffer_ptr, nb_rows * NUM_MEMORY_LOCAL_INIT_COLS)
189 };
190
191 let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
192 .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
193 .collect::<Vec<_>>();
194
195 chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
196 rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
197 let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
198
199 let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
200 for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
201 let cols = &mut cols.memory_local_entries[k];
202 if idx + k < events.len() {
203 let event = &events[idx + k];
204 cols.addr = [
205 F::from_canonical_u64(event.addr & 0xFFFF),
206 F::from_canonical_u64((event.addr >> 16) & 0xFFFF),
207 F::from_canonical_u64((event.addr >> 32) & 0xFFFF),
208 ];
209 cols.initial_clk_high = F::from_canonical_u32(
210 (event.initial_mem_access.timestamp >> 24) as u32,
211 );
212 cols.final_clk_high =
213 F::from_canonical_u32((event.final_mem_access.timestamp >> 24) as u32);
214 cols.initial_clk_low = F::from_canonical_u32(
215 (event.initial_mem_access.timestamp & 0xFFFFFF) as u32,
216 );
217 cols.final_clk_low = F::from_canonical_u32(
218 (event.final_mem_access.timestamp & 0xFFFFFF) as u32,
219 );
220 cols.initial_value = event.initial_mem_access.value.into();
221 cols.final_value = event.final_mem_access.value.into();
222 cols.is_real = F::one();
223 // split the third limb of initial value into 2 limbs of 8 bits
224 let initial_value_byte0 = (event.initial_mem_access.value >> 32) & 0xFF;
225 let initial_value_byte1 = (event.initial_mem_access.value >> 40) & 0xFF;
226 cols.initial_value_lower =
227 F::from_canonical_u32(initial_value_byte0 as u32);
228 cols.initial_value_upper =
229 F::from_canonical_u32(initial_value_byte1 as u32);
230 let final_value_byte0 = (event.final_mem_access.value >> 32) & 0xFF;
231 let final_value_byte1 = (event.final_mem_access.value >> 40) & 0xFF;
232 cols.final_value_lower = F::from_canonical_u32(final_value_byte0 as u32);
233 cols.final_value_upper = F::from_canonical_u32(final_value_byte1 as u32);
234 }
235 }
236 });
237 });
238 }
239
240 fn included(&self, shard: &Self::Record) -> bool {
241 if let Some(shape) = shard.shape.as_ref() {
242 shape.included::<F, _>(self)
243 } else {
244 shard.get_local_mem_events().nth(0).is_some()
245 }
246 }
247
248 fn column_names(&self) -> Vec<String> {
249 MemoryLocalCols::<F>::struct_reflection().unwrap()
250 }
251}
252
impl<AB> Air<AB> for MemoryLocalChip
where
    AB: SP1AirBuilder,
{
    /// Emits the constraints and interactions for every `SingleMemoryLocal` entry of the current
    /// row.
    ///
    /// For each entry this: constrains `is_real` to be boolean; ties the third limb of the
    /// initial and final values to their byte decompositions and range checks them; receives
    /// the initial access and sends the final access as `InteractionKind::Memory` interactions;
    /// and sends one `InteractionKind::Global` interaction for each of the two accesses. All
    /// interactions use `is_real` as their multiplicity.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &MemoryLocalCols<AB::Var> = (*local).borrow();

        for local in local.memory_local_entries.iter() {
            // Constrain that `local.is_real` is boolean.
            builder.assert_bool(local.is_real);

            // NOTE(review): both sides are the same expression, so this is trivially satisfied.
            // Presumably a dummy constraint that pins the chip's constraint degree to 3 —
            // confirm before touching it.
            builder.assert_eq(
                local.is_real * local.is_real * local.is_real,
                local.is_real * local.is_real * local.is_real,
            );

            // Constrain that value_lower and value_upper are the lower and upper byte of the
            // third limb of the initial value.
            builder.assert_eq(
                local.initial_value.0[2],
                local.initial_value_lower
                    + local.initial_value_upper * AB::F::from_canonical_u32(1 << 8),
            );
            builder.slice_range_check_u8(
                &[local.initial_value_lower, local.initial_value_upper],
                local.is_real,
            );
            builder.slice_range_check_u16(&local.initial_value.0, local.is_real);

            // Receive the initial access: (clk_high, clk_low, addr limbs, value limbs).
            let mut values = vec![local.initial_clk_high.into(), local.initial_clk_low.into()];
            values.extend(local.addr.map(Into::into));
            values.extend(local.initial_value.map(Into::into));
            builder.receive(
                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
                InteractionScope::Local,
            );

            // Send the "receive interaction" to the global table.
            // The first eight values follow the message built for the initial access in
            // `generate_dependencies` (the two bytes of limb 2 are folded into limbs 0 and 1).
            // NOTE(review): the trailing `0, 1, kind` presumably encode
            // (is_send, is_receive, kind), matching `is_receive: true` there — confirm.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.initial_clk_high.into(),
                        local.initial_clk_low.into(),
                        local.addr[0].into(),
                        local.addr[1].into(),
                        local.addr[2].into(),
                        local.initial_value.0[0]
                            + local.initial_value_lower * AB::F::from_canonical_u32(1 << 16),
                        local.initial_value.0[1]
                            + local.initial_value_upper * AB::F::from_canonical_u32(1 << 16),
                        local.initial_value.0[3].into(),
                        AB::Expr::zero(),
                        AB::Expr::one(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );

            // Constrain that value_lower and value_upper are the lower and upper byte of the
            // third limb of the final value.
            builder.assert_eq(
                local.final_value.0[2],
                local.final_value_lower
                    + local.final_value_upper * AB::F::from_canonical_u32(1 << 8),
            );
            builder.slice_range_check_u8(
                &[local.final_value_lower, local.final_value_upper],
                local.is_real,
            );
            builder.slice_range_check_u16(&local.final_value.0, local.is_real);

            // Send the final access: (clk_high, clk_low, addr limbs, value limbs).
            let mut values = vec![local.final_clk_high.into(), local.final_clk_low.into()];
            values.extend(local.addr.map(Into::into));
            values.extend(local.final_value.map(Into::into));
            builder.send(
                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
                InteractionScope::Local,
            );

            // Send the "send interaction" to the global table.
            // Same layout as above, built from the final access.
            // NOTE(review): the trailing `1, 0, kind` presumably encode
            // (is_send, is_receive, kind), matching `is_receive: false` in
            // `generate_dependencies` — confirm.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.final_clk_high.into(),
                        local.final_clk_low.into(),
                        local.addr[0].into(),
                        local.addr[1].into(),
                        local.addr[2].into(),
                        local.final_value.0[0]
                            + local.final_value_lower * AB::F::from_canonical_u32(1 << 16),
                        local.final_value.0[1]
                            + local.final_value_upper * AB::F::from_canonical_u32(1 << 16),
                        local.final_value.0[3].into(),
                        AB::Expr::one(),
                        AB::Expr::zero(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        }
    }
}
361
362// #[cfg(test)]
363// mod tests {
364// #![allow(clippy::print_stdout)]
365
366// use crate::programs::tests::*;
367// use crate::{
368// memory::MemoryLocalChip, riscv::RiscvAir,
369// syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
370// };
371// use sp1_primitives::SP1Field;
372// use slop_matrix::dense::RowMajorMatrix;
373// use sp1_core_executor::{ExecutionRecord, Executor, Trace};
374// use sp1_hypercube::{
375// air::{InteractionScope, MachineAir},
376// koala_bear_poseidon2::SP1InnerPcs,
377// debug_interactions_with_all_chips, InteractionKind, SP1CoreOpts, StarkMachine,
378// };
379
380// #[test]
381// fn test_local_memory_generate_trace() {
382// let program = simple_program();
383// let mut runtime = Executor::new(program, SP1CoreOpts::default());
384// runtime.run::<Trace>().unwrap();
385// let shard = runtime.records[0].clone();
386
387// let chip: MemoryLocalChip = MemoryLocalChip::new();
388
389// let trace: RowMajorMatrix<SP1Field> =
390// chip.generate_trace(&shard, &mut ExecutionRecord::default());
391// println!("{:?}", trace.values);
392
393// for mem_event in shard.global_memory_finalize_events {
394// println!("{mem_event:?}");
395// }
396// }
397
398// #[test]
399// fn test_memory_lookup_interactions() {
400// setup_logger();
401// let program = sha_extend_program();
402// let program_clone = program.clone();
403// let mut runtime = Executor::new(program, SP1CoreOpts::default());
404// runtime.run::<Trace>().unwrap();
405// let machine: StarkMachine<SP1InnerPcs, RiscvAir<SP1Field>> =
406// RiscvAir::machine(SP1InnerPcs::new());
407// let (pkey, _) = machine.setup(&program_clone);
408// let opts = SP1CoreOpts::default();
409// machine.generate_dependencies(
410// &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
411// &opts,
412// None,
413// );
414
415// let shards = runtime.records;
416// for shard in shards.clone() {
417// debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
418// &machine,
419// &pkey,
420// &[*shard],
421// vec![InteractionKind::Memory],
422// InteractionScope::Local,
423// );
424// }
425// debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
426// &machine,
427// &pkey,
428// &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
429// vec![InteractionKind::Memory],
430// InteractionScope::Global,
431// );
432// }
433
434// #[test]
435// fn test_byte_lookup_interactions() {
436// setup_logger();
437// let program = sha_extend_program();
438// let program_clone = program.clone();
439// let mut runtime = Executor::new(program, SP1CoreOpts::default());
440// runtime.run::<Trace>().unwrap();
441// let machine = RiscvAir::machine(SP1InnerPcs::new());
442// let (pkey, _) = machine.setup(&program_clone);
443// let opts = SP1CoreOpts::default();
444// machine.generate_dependencies(
445// &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
446// &opts,
447// None,
448// );
449
450// let shards = runtime.records;
451// for shard in shards.clone() {
452// debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
453// &machine,
454// &pkey,
455// &[*shard],
456// vec![InteractionKind::Memory],
457// InteractionScope::Local,
458// );
459// }
460// debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
461// &machine,
462// &pkey,
463// &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
464// vec![InteractionKind::Byte],
465// InteractionScope::Global,
466// );
467// }
468
469// #[cfg(feature = "sys")]
470// fn get_test_execution_record() -> ExecutionRecord {
471// use slop_algebra::PrimeField32;
472// use rand::{thread_rng, Rng};
473// use sp1_core_executor::events::{MemoryLocalEvent, MemoryRecord};
474
475// let cpu_local_memory_access = (0..=255)
476// .flat_map(|_| {
477// [{
478// let addr = thread_rng().gen_range(0..SP1Field::ORDER_U32);
479// let init_value = thread_rng().gen_range(0..u32::MAX);
480// let init_shard = thread_rng().gen_range(0..(1u32 << 16));
481// let init_timestamp = thread_rng().gen_range(0..(1u32 << 24));
482// let final_value = thread_rng().gen_range(0..u32::MAX);
483// let final_timestamp = thread_rng().gen_range(0..(1u32 << 24));
484// let final_shard = thread_rng().gen_range(0..(1u32 << 16));
485// MemoryLocalEvent {
486// addr,
487// initial_mem_access: MemoryRecord {
488// shard: init_shard,
489// timestamp: init_timestamp,
490// value: init_value,
491// },
492// final_mem_access: MemoryRecord {
493// shard: final_shard,
494// timestamp: final_timestamp,
495// value: final_value,
496// },
497// }
498// }]
499// })
500// .collect::<Vec<_>>();
501// ExecutionRecord { cpu_local_memory_access, ..Default::default() }
502// }
503
504// #[cfg(feature = "sys")]
505// #[test]
506// fn test_generate_trace_ffi_eq_rust() {
507// use slop_matrix::Matrix;
508
509// let record = get_test_execution_record();
510// let chip = MemoryLocalChip::new();
511// let trace: RowMajorMatrix<SP1Field> =
512// chip.generate_trace(&record, &mut ExecutionRecord::default());
513// let trace_ffi = generate_trace_ffi(&record, trace.height());
514
515// assert_eq!(trace_ffi, trace);
516// }
517
518// #[cfg(feature = "sys")]
519// fn generate_trace_ffi(input: &ExecutionRecord, height: usize) -> RowMajorMatrix<SP1Field> {
520// use std::borrow::BorrowMut;
521
522// use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
523
524// use crate::{
525// memory::{
526// MemoryLocalCols, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, NUM_MEMORY_LOCAL_INIT_COLS,
527// },
528// utils::zeroed_f_vec,
529// };
530
531// use sp1_primitives::SP1Field;
532// type F = SP1Field;
533// // Generate the trace rows for each event.
534// let events = input.get_local_mem_events().collect::<Vec<_>>();
535// let nb_rows = events.len().div_ceil(4);
536// let padded_nb_rows = height;
537// let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
538// let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
539
540// let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
541// .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
542// .collect::<Vec<_>>();
543
544// chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
545// rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
546// let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
547// let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
548// for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
549// let cols = &mut cols.memory_local_entries[k];
550// if idx + k < events.len() {
551// unsafe {
552// crate::sys::memory_local_event_to_row_koalabear(events[idx + k],
553// cols); }
554// }
555// }
556// });
557// });
558
559// // Convert the trace to a row major matrix.
560// RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
561// }
562// }