sp1_core_machine/memory/
local.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5
6use crate::utils::{next_power_of_two, zeroed_f_vec};
7
8use p3_air::{Air, BaseAir};
9use p3_field::{AbstractField, PrimeField32};
10use p3_matrix::{dense::RowMajorMatrix, Matrix};
11use p3_maybe_rayon::prelude::{
12    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
13};
14use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord, Program};
15use sp1_derive::AlignedBorrow;
16use sp1_stark::{
17    air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
18    InteractionKind, Word,
19};
20
21pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 4;
22// pub const NUM_LOCAL_MEMORY_INTERACTIONS_PER_ROW: usize = NUM_LOCAL_MEMORY_ENTRIES_PER_ROW * 2;
23pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::<MemoryLocalCols<u8>>();
24
25// const_assert!(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW_EXEC == NUM_LOCAL_MEMORY_ENTRIES_PER_ROW);
26
27#[derive(AlignedBorrow, Clone, Copy)]
28#[repr(C)]
29pub struct SingleMemoryLocal<T: Copy> {
30    /// The address of the memory access.
31    pub addr: T,
32
33    /// The initial shard of the memory access.
34    pub initial_shard: T,
35
36    /// The final shard of the memory access.
37    pub final_shard: T,
38
39    /// The initial clk of the memory access.
40    pub initial_clk: T,
41
42    /// The final clk of the memory access.
43    pub final_clk: T,
44
45    /// The initial value of the memory access.
46    pub initial_value: Word<T>,
47
48    /// The final value of the memory access.
49    pub final_value: Word<T>,
50
51    /// Whether the memory access is a real access.
52    pub is_real: T,
53}
54
55#[derive(AlignedBorrow, Clone, Copy)]
56#[repr(C)]
57pub struct MemoryLocalCols<T: Copy> {
58    memory_local_entries: [SingleMemoryLocal<T>; NUM_LOCAL_MEMORY_ENTRIES_PER_ROW],
59}
60
/// Chip for the `MemoryLocal` table.
///
/// The chip carries no state; all of its behaviour lives in its trait implementations.
#[derive(Default)]
pub struct MemoryLocalChip {}

impl MemoryLocalChip {
    /// Creates a new (stateless) `MemoryLocalChip`.
    ///
    /// Equivalent to `MemoryLocalChip::default()`, but usable in `const` contexts.
    pub const fn new() -> Self {
        Self {}
    }
}
69
70impl<F> BaseAir<F> for MemoryLocalChip {
71    fn width(&self) -> usize {
72        NUM_MEMORY_LOCAL_INIT_COLS
73    }
74}
75
76fn nb_rows(count: usize) -> usize {
77    if NUM_LOCAL_MEMORY_ENTRIES_PER_ROW > 1 {
78        count.div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
79    } else {
80        count
81    }
82}
83
84impl<F: PrimeField32> MachineAir<F> for MemoryLocalChip {
85    type Record = ExecutionRecord;
86
87    type Program = Program;
88
89    fn name(&self) -> String {
90        "MemoryLocal".to_string()
91    }
92
93    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
94        let mut events = Vec::new();
95
96        input.get_local_mem_events().for_each(|mem_event| {
97            events.push(GlobalInteractionEvent {
98                message: [
99                    mem_event.initial_mem_access.shard,
100                    mem_event.initial_mem_access.timestamp,
101                    mem_event.addr,
102                    mem_event.initial_mem_access.value & 255,
103                    (mem_event.initial_mem_access.value >> 8) & 255,
104                    (mem_event.initial_mem_access.value >> 16) & 255,
105                    (mem_event.initial_mem_access.value >> 24) & 255,
106                ],
107                is_receive: true,
108                kind: InteractionKind::Memory as u8,
109            });
110            events.push(GlobalInteractionEvent {
111                message: [
112                    mem_event.final_mem_access.shard,
113                    mem_event.final_mem_access.timestamp,
114                    mem_event.addr,
115                    mem_event.final_mem_access.value & 255,
116                    (mem_event.final_mem_access.value >> 8) & 255,
117                    (mem_event.final_mem_access.value >> 16) & 255,
118                    (mem_event.final_mem_access.value >> 24) & 255,
119                ],
120                is_receive: false,
121                kind: InteractionKind::Memory as u8,
122            });
123        });
124
125        output.global_interaction_events.extend(events);
126    }
127
128    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
129        let count = input.get_local_mem_events().count();
130        let nb_rows = nb_rows(count);
131        let size_log2 = input.fixed_log2_rows::<F, _>(self);
132        Some(next_power_of_two(nb_rows, size_log2))
133    }
134
135    fn generate_trace(
136        &self,
137        input: &Self::Record,
138        _output: &mut Self::Record,
139    ) -> RowMajorMatrix<F> {
140        // Generate the trace rows for each event.
141        let events = input.get_local_mem_events().collect::<Vec<_>>();
142        let nb_rows = nb_rows(events.len());
143        let padded_nb_rows = <MemoryLocalChip as MachineAir<F>>::num_rows(self, input).unwrap();
144        let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
145        let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
146
147        let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
148            .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
149            .collect::<Vec<_>>();
150
151        chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
152            rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
153                let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
154
155                let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
156                for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
157                    let cols = &mut cols.memory_local_entries[k];
158                    if idx + k < events.len() {
159                        let event = &events[idx + k];
160                        cols.addr = F::from_canonical_u32(event.addr);
161                        cols.initial_shard = F::from_canonical_u32(event.initial_mem_access.shard);
162                        cols.final_shard = F::from_canonical_u32(event.final_mem_access.shard);
163                        cols.initial_clk =
164                            F::from_canonical_u32(event.initial_mem_access.timestamp);
165                        cols.final_clk = F::from_canonical_u32(event.final_mem_access.timestamp);
166                        cols.initial_value = event.initial_mem_access.value.into();
167                        cols.final_value = event.final_mem_access.value.into();
168                        cols.is_real = F::one();
169                    }
170                }
171            });
172        });
173
174        // Convert the trace to a row major matrix.
175        RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
176    }
177
178    fn included(&self, shard: &Self::Record) -> bool {
179        if let Some(shape) = shard.shape.as_ref() {
180            shape.included::<F, _>(self)
181        } else {
182            shard.get_local_mem_events().nth(0).is_some()
183        }
184    }
185
186    fn commit_scope(&self) -> InteractionScope {
187        InteractionScope::Local
188    }
189}
190
191impl<AB> Air<AB> for MemoryLocalChip
192where
193    AB: SP1AirBuilder,
194{
195    fn eval(&self, builder: &mut AB) {
196        let main = builder.main();
197        let local = main.row_slice(0);
198        let local: &MemoryLocalCols<AB::Var> = (*local).borrow();
199
200        for local in local.memory_local_entries.iter() {
201            // Constrain that `local.is_real` is boolean.
202            builder.assert_bool(local.is_real);
203
204            builder.assert_eq(
205                local.is_real * local.is_real * local.is_real,
206                local.is_real * local.is_real * local.is_real,
207            );
208
209            let mut values =
210                vec![local.initial_shard.into(), local.initial_clk.into(), local.addr.into()];
211            values.extend(local.initial_value.map(Into::into));
212            builder.receive(
213                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
214                InteractionScope::Local,
215            );
216
217            // Send the "receive interaction" to the global table.
218            builder.send(
219                AirInteraction::new(
220                    vec![
221                        local.initial_shard.into(),
222                        local.initial_clk.into(),
223                        local.addr.into(),
224                        local.initial_value[0].into(),
225                        local.initial_value[1].into(),
226                        local.initial_value[2].into(),
227                        local.initial_value[3].into(),
228                        AB::Expr::zero(),
229                        AB::Expr::one(),
230                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
231                    ],
232                    local.is_real.into(),
233                    InteractionKind::Global,
234                ),
235                InteractionScope::Local,
236            );
237
238            // Send the "send interaction" to the global table.
239            builder.send(
240                AirInteraction::new(
241                    vec![
242                        local.final_shard.into(),
243                        local.final_clk.into(),
244                        local.addr.into(),
245                        local.final_value[0].into(),
246                        local.final_value[1].into(),
247                        local.final_value[2].into(),
248                        local.final_value[3].into(),
249                        AB::Expr::one(),
250                        AB::Expr::zero(),
251                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
252                    ],
253                    local.is_real.into(),
254                    InteractionKind::Global,
255                ),
256                InteractionScope::Local,
257            );
258
259            let mut values =
260                vec![local.final_shard.into(), local.final_clk.into(), local.addr.into()];
261            values.extend(local.final_value.map(Into::into));
262            builder.send(
263                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
264                InteractionScope::Local,
265            );
266        }
267    }
268}
269
#[cfg(test)]
mod tests {
    // These tests print traces and events on purpose, for manual inspection.
    #![allow(clippy::print_stdout)]

    use crate::{
        memory::MemoryLocalChip, programs::tests::*, riscv::RiscvAir,
        syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
    };
    use p3_baby_bear::BabyBear;
    use p3_matrix::dense::RowMajorMatrix;
    use sp1_core_executor::{ExecutionRecord, Executor};
    use sp1_stark::{
        air::{InteractionScope, MachineAir},
        baby_bear_poseidon2::BabyBearPoseidon2,
        debug_interactions_with_all_chips, InteractionKind, SP1CoreOpts, StarkMachine,
    };

    /// Smoke test: generates the trace for the first record of `simple_program` and prints it,
    /// followed by the record's global memory finalize events.
    #[test]
    fn test_local_memory_generate_trace() {
        let program = simple_program();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let shard = runtime.records[0].clone();

        let chip: MemoryLocalChip = MemoryLocalChip::new();

        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        for mem_event in shard.global_memory_finalize_events {
            println!("{mem_event:?}");
        }
    }

    /// Executes the SHA-extend program, generates dependencies for its records, then runs
    /// `debug_interactions_with_all_chips` twice:
    ///
    /// * once per shard, for `InteractionKind::Memory` in the local scope, and
    /// * once over all shards, for `global_kind` in the global scope.
    fn debug_sha_extend_interactions(global_kind: InteractionKind) {
        setup_logger();
        let program = sha_extend_program();
        let program_clone = program.clone();
        let mut runtime = Executor::new(program, SP1CoreOpts::default());
        runtime.run().unwrap();
        let machine: StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program_clone);
        let opts = SP1CoreOpts::default();
        machine.generate_dependencies(
            &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
            &opts,
            None,
        );

        let shards = runtime.records;
        for shard in shards.clone() {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[*shard],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
            vec![global_kind],
            InteractionScope::Global,
        );
    }

    #[test]
    fn test_memory_lookup_interactions() {
        debug_sha_extend_interactions(InteractionKind::Memory);
    }

    // NOTE(review): the per-shard local pass of this test has always debugged
    // `InteractionKind::Memory`; only the global pass uses `Byte`. That behaviour is preserved
    // here, but it looks like a copy-paste leftover — confirm whether the local pass should use
    // `Byte` as well.
    #[test]
    fn test_byte_lookup_interactions() {
        debug_sha_extend_interactions(InteractionKind::Byte);
    }

    /// Builds an execution record holding 256 randomly generated local memory events.
    #[cfg(feature = "sys")]
    fn get_test_execution_record() -> ExecutionRecord {
        use p3_field::PrimeField32;
        use rand::{thread_rng, Rng};
        use sp1_core_executor::events::{MemoryLocalEvent, MemoryRecord};

        // Fetch the thread-local generator once instead of once per sampled field.
        let mut rng = thread_rng();
        let cpu_local_memory_access = (0..=255)
            .map(|_| {
                let addr = rng.gen_range(0..BabyBear::ORDER_U32);
                let init_value = rng.gen_range(0..u32::MAX);
                let init_shard = rng.gen_range(0..(1u32 << 16));
                let init_timestamp = rng.gen_range(0..(1u32 << 24));
                let final_value = rng.gen_range(0..u32::MAX);
                let final_timestamp = rng.gen_range(0..(1u32 << 24));
                let final_shard = rng.gen_range(0..(1u32 << 16));
                MemoryLocalEvent {
                    addr,
                    initial_mem_access: MemoryRecord {
                        shard: init_shard,
                        timestamp: init_timestamp,
                        value: init_value,
                    },
                    final_mem_access: MemoryRecord {
                        shard: final_shard,
                        timestamp: final_timestamp,
                        value: final_value,
                    },
                }
            })
            .collect::<Vec<_>>();
        ExecutionRecord { cpu_local_memory_access, ..Default::default() }
    }

    /// Checks that the FFI row filler produces exactly the same trace as the Rust
    /// `generate_trace` implementation for a random record.
    #[cfg(feature = "sys")]
    #[test]
    fn test_generate_trace_ffi_eq_rust() {
        use p3_matrix::Matrix;

        let record = get_test_execution_record();
        let chip = MemoryLocalChip::new();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&record, &mut ExecutionRecord::default());
        let trace_ffi = generate_trace_ffi(&record, trace.height());

        assert_eq!(trace_ffi, trace);
    }

    /// Builds a trace of `height` rows using the same chunking as `generate_trace`, but fills
    /// each entry through `crate::sys::memory_local_event_to_row_babybear`.
    #[cfg(feature = "sys")]
    fn generate_trace_ffi(input: &ExecutionRecord, height: usize) -> RowMajorMatrix<BabyBear> {
        use std::borrow::BorrowMut;

        use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};

        use crate::{
            memory::{
                MemoryLocalCols, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, NUM_MEMORY_LOCAL_INIT_COLS,
            },
            utils::zeroed_f_vec,
        };

        type F = BabyBear;
        // Generate the trace rows for each event.
        let events = input.get_local_mem_events().collect::<Vec<_>>();
        // Use the shared constant rather than a literal so this stays in sync with the chip.
        let nb_rows = events.len().div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW);
        let padded_nb_rows = height;
        let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
        // The `+ 1` keeps the chunk size non-zero even when there are fewer rows than CPUs.
        let chunk_size = nb_rows / num_cpus::get() + 1;

        let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
            .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
            .collect::<Vec<_>>();

        chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
            rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
                let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
                let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
                for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
                    let cols = &mut cols.memory_local_entries[k];
                    if idx + k < events.len() {
                        // SAFETY: NOTE(review) — presumably the FFI routine only reads the event
                        // and writes the columns of the single entry passed in `cols`, both of
                        // which are valid for the duration of the call; confirm against the
                        // declaration in `crate::sys`.
                        unsafe {
                            crate::sys::memory_local_event_to_row_babybear(events[idx + k], cols);
                        }
                    }
                }
            });
        });

        // Convert the trace to a row major matrix.
        RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
    }
}