Skip to main content

sp1_core_machine/memory/
local.rs

1use std::{
2    borrow::{Borrow, BorrowMut},
3    mem::{size_of, MaybeUninit},
4};
5
6use crate::{air::WordAirBuilder, utils::next_multiple_of_32};
7use slop_air::{Air, BaseAir};
8use slop_algebra::{AbstractField, PrimeField32};
9use slop_matrix::Matrix;
10use slop_maybe_rayon::prelude::{
11    IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
12};
13use sp1_core_executor::{
14    events::{ByteRecord, GlobalInteractionEvent},
15    ExecutionRecord, Program,
16};
17use sp1_derive::AlignedBorrow;
18use sp1_hypercube::{
19    air::{AirInteraction, InteractionScope, MachineAir, SP1AirBuilder},
20    InteractionKind, Word,
21};
22use struct_reflection::{StructReflection, StructReflectionHelper};
23
24pub const NUM_LOCAL_MEMORY_ENTRIES_PER_ROW: usize = 1;
25pub(crate) const NUM_MEMORY_LOCAL_INIT_COLS: usize = size_of::<MemoryLocalCols<u8>>();
26
27#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
28#[repr(C)]
29pub struct SingleMemoryLocal<T: Copy> {
30    /// The address of the memory access.
31    pub addr: [T; 3],
32
33    /// The high bits of initial clk of the memory access.
34    pub initial_clk_high: T,
35
36    /// The high bits of final clk of the memory access.
37    pub final_clk_high: T,
38
39    /// The low bits of initial clk of the memory access.
40    pub initial_clk_low: T,
41
42    /// The low bits of final clk of the memory access.
43    pub final_clk_low: T,
44
45    /// The initial value of the memory access.
46    pub initial_value: Word<T>,
47
48    /// The final value of the memory access.
49    pub final_value: Word<T>,
50
51    /// Lower half of third limb of the initial value
52    pub initial_value_lower: T,
53
54    /// Upper half of third limb of the initial value
55    pub initial_value_upper: T,
56
57    /// Lower half of third limb of the final value
58    pub final_value_lower: T,
59
60    /// Upper half of third limb of the final value
61    pub final_value_upper: T,
62
63    /// Whether the memory access is a real access.
64    pub is_real: T,
65}
66
67#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
68#[repr(C)]
69pub struct MemoryLocalCols<T: Copy> {
70    memory_local_entries: [SingleMemoryLocal<T>; NUM_LOCAL_MEMORY_ENTRIES_PER_ROW],
71}
72
/// The chip that handles local memory events; it carries no state of its own.
pub struct MemoryLocalChip {}

impl MemoryLocalChip {
    /// Creates a new local memory chip.
    pub const fn new() -> Self {
        Self {}
    }
}
81
82impl<F> BaseAir<F> for MemoryLocalChip {
83    fn width(&self) -> usize {
84        NUM_MEMORY_LOCAL_INIT_COLS
85    }
86}
87
88fn nb_rows(count: usize) -> usize {
89    if NUM_LOCAL_MEMORY_ENTRIES_PER_ROW > 1 {
90        count.div_ceil(NUM_LOCAL_MEMORY_ENTRIES_PER_ROW)
91    } else {
92        count
93    }
94}
95
96impl<F: PrimeField32> MachineAir<F> for MemoryLocalChip {
97    type Record = ExecutionRecord;
98
99    type Program = Program;
100
101    fn name(&self) -> &'static str {
102        "MemoryLocal"
103    }
104
105    fn generate_dependencies(&self, input: &Self::Record, output: &mut Self::Record) {
106        let mut events = Vec::new();
107
108        input.get_local_mem_events().for_each(|mem_event| {
109            let mut blu = Vec::with_capacity(10); // 1 + 4 + 1 + 4
110            let initial_value_byte0 = ((mem_event.initial_mem_access.value >> 32) & 0xFF) as u32;
111            let initial_value_byte1 = ((mem_event.initial_mem_access.value >> 40) & 0xFF) as u32;
112            blu.add_u8_range_check(initial_value_byte0 as u8, initial_value_byte1 as u8);
113            blu.add_u16_range_checks_field::<F>(&Word::from(mem_event.initial_mem_access.value).0);
114
115            events.push(GlobalInteractionEvent {
116                message: [
117                    (mem_event.initial_mem_access.timestamp >> 24) as u32,
118                    (mem_event.initial_mem_access.timestamp & 0xFFFFFF) as u32,
119                    (mem_event.addr & 0xFFFF) as u32,
120                    ((mem_event.addr >> 16) & 0xFFFF) as u32,
121                    ((mem_event.addr >> 32) & 0xFFFF) as u32,
122                    (mem_event.initial_mem_access.value & 0xFFFF) as u32
123                        + (1 << 16) * initial_value_byte0,
124                    ((mem_event.initial_mem_access.value >> 16) & 0xFFFF) as u32
125                        + (1 << 16) * initial_value_byte1,
126                    ((mem_event.initial_mem_access.value >> 48) & 0xFFFF) as u32,
127                ],
128                is_receive: true,
129                kind: InteractionKind::Memory as u8,
130            });
131
132            let final_value_byte0 = ((mem_event.final_mem_access.value >> 32) & 0xFF) as u32;
133            let final_value_byte1 = ((mem_event.final_mem_access.value >> 40) & 0xFF) as u32;
134            blu.add_u8_range_check(final_value_byte0 as u8, final_value_byte1 as u8);
135            blu.add_u16_range_checks_field::<F>(&Word::from(mem_event.final_mem_access.value).0);
136            events.push(GlobalInteractionEvent {
137                message: [
138                    (mem_event.final_mem_access.timestamp >> 24) as u32,
139                    (mem_event.final_mem_access.timestamp & 0xFFFFFF) as u32,
140                    (mem_event.addr & 0xFFFF) as u32,
141                    ((mem_event.addr >> 16) & 0xFFFF) as u32,
142                    ((mem_event.addr >> 32) & 0xFFFF) as u32,
143                    (mem_event.final_mem_access.value & 0xFFFF) as u32
144                        + (1 << 16) * final_value_byte0,
145                    ((mem_event.final_mem_access.value >> 16) & 0xFFFF) as u32
146                        + (1 << 16) * final_value_byte1,
147                    ((mem_event.final_mem_access.value >> 48) & 0xFFFF) as u32,
148                ],
149                is_receive: false,
150                kind: InteractionKind::Memory as u8,
151            });
152
153            output.add_byte_lookup_events(blu);
154        });
155
156        output.global_interaction_events.extend(events);
157    }
158
159    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
160        let count = input.get_local_mem_events().count();
161        let nb_rows = nb_rows(count);
162        let size_log2 = input.fixed_log2_rows::<F, _>(self);
163        Some(next_multiple_of_32(nb_rows, size_log2))
164    }
165
166    fn generate_trace_into(
167        &self,
168        input: &ExecutionRecord,
169        _output: &mut ExecutionRecord,
170        buffer: &mut [MaybeUninit<F>],
171    ) {
172        // Generate the trace rows for each event.
173        let events = input.get_local_mem_events().collect::<Vec<_>>();
174        let nb_rows = nb_rows(events.len());
175        let padded_nb_rows = <MemoryLocalChip as MachineAir<F>>::num_rows(self, input).unwrap();
176        let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
177
178        unsafe {
179            let padding_start = nb_rows * NUM_MEMORY_LOCAL_INIT_COLS;
180            let padding_size = (padded_nb_rows - nb_rows) * NUM_MEMORY_LOCAL_INIT_COLS;
181            if padding_size > 0 {
182                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
183            }
184        }
185
186        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
187        let values = unsafe {
188            core::slice::from_raw_parts_mut(buffer_ptr, nb_rows * NUM_MEMORY_LOCAL_INIT_COLS)
189        };
190
191        let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
192            .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
193            .collect::<Vec<_>>();
194
195        chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
196            rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
197                let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
198
199                let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
200                for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
201                    let cols = &mut cols.memory_local_entries[k];
202                    if idx + k < events.len() {
203                        let event = &events[idx + k];
204                        cols.addr = [
205                            F::from_canonical_u64(event.addr & 0xFFFF),
206                            F::from_canonical_u64((event.addr >> 16) & 0xFFFF),
207                            F::from_canonical_u64((event.addr >> 32) & 0xFFFF),
208                        ];
209                        cols.initial_clk_high = F::from_canonical_u32(
210                            (event.initial_mem_access.timestamp >> 24) as u32,
211                        );
212                        cols.final_clk_high =
213                            F::from_canonical_u32((event.final_mem_access.timestamp >> 24) as u32);
214                        cols.initial_clk_low = F::from_canonical_u32(
215                            (event.initial_mem_access.timestamp & 0xFFFFFF) as u32,
216                        );
217                        cols.final_clk_low = F::from_canonical_u32(
218                            (event.final_mem_access.timestamp & 0xFFFFFF) as u32,
219                        );
220                        cols.initial_value = event.initial_mem_access.value.into();
221                        cols.final_value = event.final_mem_access.value.into();
222                        cols.is_real = F::one();
223                        // split the third limb of initial value into 2 limbs of 8 bits
224                        let initial_value_byte0 = (event.initial_mem_access.value >> 32) & 0xFF;
225                        let initial_value_byte1 = (event.initial_mem_access.value >> 40) & 0xFF;
226                        cols.initial_value_lower =
227                            F::from_canonical_u32(initial_value_byte0 as u32);
228                        cols.initial_value_upper =
229                            F::from_canonical_u32(initial_value_byte1 as u32);
230                        let final_value_byte0 = (event.final_mem_access.value >> 32) & 0xFF;
231                        let final_value_byte1 = (event.final_mem_access.value >> 40) & 0xFF;
232                        cols.final_value_lower = F::from_canonical_u32(final_value_byte0 as u32);
233                        cols.final_value_upper = F::from_canonical_u32(final_value_byte1 as u32);
234                    }
235                }
236            });
237        });
238    }
239
240    fn included(&self, shard: &Self::Record) -> bool {
241        if let Some(shape) = shard.shape.as_ref() {
242            shape.included::<F, _>(self)
243        } else {
244            shard.get_local_mem_events().nth(0).is_some()
245        }
246    }
247
248    fn column_names(&self) -> Vec<String> {
249        MemoryLocalCols::<F>::struct_reflection().unwrap()
250    }
251}
252
253impl<AB> Air<AB> for MemoryLocalChip
254where
255    AB: SP1AirBuilder,
256{
257    fn eval(&self, builder: &mut AB) {
258        let main = builder.main();
259        let local = main.row_slice(0);
260        let local: &MemoryLocalCols<AB::Var> = (*local).borrow();
261
262        for local in local.memory_local_entries.iter() {
263            // Constrain that `local.is_real` is boolean.
264            builder.assert_bool(local.is_real);
265
266            builder.assert_eq(
267                local.is_real * local.is_real * local.is_real,
268                local.is_real * local.is_real * local.is_real,
269            );
270
271            // Constrain that value_lower and value_upper are the lower and upper byte of the limb.
272            builder.assert_eq(
273                local.initial_value.0[2],
274                local.initial_value_lower
275                    + local.initial_value_upper * AB::F::from_canonical_u32(1 << 8),
276            );
277            builder.slice_range_check_u8(
278                &[local.initial_value_lower, local.initial_value_upper],
279                local.is_real,
280            );
281            builder.slice_range_check_u16(&local.initial_value.0, local.is_real);
282
283            let mut values = vec![local.initial_clk_high.into(), local.initial_clk_low.into()];
284            values.extend(local.addr.map(Into::into));
285            values.extend(local.initial_value.map(Into::into));
286            builder.receive(
287                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
288                InteractionScope::Local,
289            );
290
291            // Send the "receive interaction" to the global table.
292            builder.send(
293                AirInteraction::new(
294                    vec![
295                        local.initial_clk_high.into(),
296                        local.initial_clk_low.into(),
297                        local.addr[0].into(),
298                        local.addr[1].into(),
299                        local.addr[2].into(),
300                        local.initial_value.0[0]
301                            + local.initial_value_lower * AB::F::from_canonical_u32(1 << 16),
302                        local.initial_value.0[1]
303                            + local.initial_value_upper * AB::F::from_canonical_u32(1 << 16),
304                        local.initial_value.0[3].into(),
305                        AB::Expr::zero(),
306                        AB::Expr::one(),
307                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
308                    ],
309                    local.is_real.into(),
310                    InteractionKind::Global,
311                ),
312                InteractionScope::Local,
313            );
314
315            // Constrain that value_lower and value_upper are the lower and upper byte of the limb.
316            builder.assert_eq(
317                local.final_value.0[2],
318                local.final_value_lower
319                    + local.final_value_upper * AB::F::from_canonical_u32(1 << 8),
320            );
321            builder.slice_range_check_u8(
322                &[local.final_value_lower, local.final_value_upper],
323                local.is_real,
324            );
325            builder.slice_range_check_u16(&local.final_value.0, local.is_real);
326
327            let mut values = vec![local.final_clk_high.into(), local.final_clk_low.into()];
328            values.extend(local.addr.map(Into::into));
329            values.extend(local.final_value.map(Into::into));
330            builder.send(
331                AirInteraction::new(values.clone(), local.is_real.into(), InteractionKind::Memory),
332                InteractionScope::Local,
333            );
334
335            // Send the "send interaction" to the global table.
336            builder.send(
337                AirInteraction::new(
338                    vec![
339                        local.final_clk_high.into(),
340                        local.final_clk_low.into(),
341                        local.addr[0].into(),
342                        local.addr[1].into(),
343                        local.addr[2].into(),
344                        local.final_value.0[0]
345                            + local.final_value_lower * AB::F::from_canonical_u32(1 << 16),
346                        local.final_value.0[1]
347                            + local.final_value_upper * AB::F::from_canonical_u32(1 << 16),
348                        local.final_value.0[3].into(),
349                        AB::Expr::one(),
350                        AB::Expr::zero(),
351                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
352                    ],
353                    local.is_real.into(),
354                    InteractionKind::Global,
355                ),
356                InteractionScope::Local,
357            );
358        }
359    }
360}
361
362// #[cfg(test)]
363// mod tests {
364//     #![allow(clippy::print_stdout)]
365
366//     use crate::programs::tests::*;
367//     use crate::{
368//         memory::MemoryLocalChip, riscv::RiscvAir,
369//         syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
370//     };
371//     use sp1_primitives::SP1Field;
372//     use slop_matrix::dense::RowMajorMatrix;
373//     use sp1_core_executor::{ExecutionRecord, Executor, Trace};
374//     use sp1_hypercube::{
375//         air::{InteractionScope, MachineAir},
376//         koala_bear_poseidon2::SP1InnerPcs,
377//         debug_interactions_with_all_chips, InteractionKind, SP1CoreOpts, StarkMachine,
378//     };
379
380//     #[test]
381//     fn test_local_memory_generate_trace() {
382//         let program = simple_program();
383//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
384//         runtime.run::<Trace>().unwrap();
385//         let shard = runtime.records[0].clone();
386
387//         let chip: MemoryLocalChip = MemoryLocalChip::new();
388
389//         let trace: RowMajorMatrix<SP1Field> =
390//             chip.generate_trace(&shard, &mut ExecutionRecord::default());
391//         println!("{:?}", trace.values);
392
393//         for mem_event in shard.global_memory_finalize_events {
394//             println!("{mem_event:?}");
395//         }
396//     }
397
398//     #[test]
399//     fn test_memory_lookup_interactions() {
400//         setup_logger();
401//         let program = sha_extend_program();
402//         let program_clone = program.clone();
403//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
404//         runtime.run::<Trace>().unwrap();
405//         let machine: StarkMachine<SP1InnerPcs, RiscvAir<SP1Field>> =
406//             RiscvAir::machine(SP1InnerPcs::new());
407//         let (pkey, _) = machine.setup(&program_clone);
408//         let opts = SP1CoreOpts::default();
409//         machine.generate_dependencies(
410//             &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
411//             &opts,
412//             None,
413//         );
414
415//         let shards = runtime.records;
416//         for shard in shards.clone() {
417//             debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
418//                 &machine,
419//                 &pkey,
420//                 &[*shard],
421//                 vec![InteractionKind::Memory],
422//                 InteractionScope::Local,
423//             );
424//         }
425//         debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
426//             &machine,
427//             &pkey,
428//             &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
429//             vec![InteractionKind::Memory],
430//             InteractionScope::Global,
431//         );
432//     }
433
434//     #[test]
435//     fn test_byte_lookup_interactions() {
436//         setup_logger();
437//         let program = sha_extend_program();
438//         let program_clone = program.clone();
439//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
440//         runtime.run::<Trace>().unwrap();
441//         let machine = RiscvAir::machine(SP1InnerPcs::new());
442//         let (pkey, _) = machine.setup(&program_clone);
443//         let opts = SP1CoreOpts::default();
444//         machine.generate_dependencies(
445//             &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
446//             &opts,
447//             None,
448//         );
449
450//         let shards = runtime.records;
451//         for shard in shards.clone() {
452//             debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
453//                 &machine,
454//                 &pkey,
455//                 &[*shard],
456//                 vec![InteractionKind::Memory],
457//                 InteractionScope::Local,
458//             );
459//         }
460//         debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
461//             &machine,
462//             &pkey,
463//             &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
464//             vec![InteractionKind::Byte],
465//             InteractionScope::Global,
466//         );
467//     }
468
469//     #[cfg(feature = "sys")]
470//     fn get_test_execution_record() -> ExecutionRecord {
471//         use slop_algebra::PrimeField32;
472//         use rand::{thread_rng, Rng};
473//         use sp1_core_executor::events::{MemoryLocalEvent, MemoryRecord};
474
475//         let cpu_local_memory_access = (0..=255)
476//             .flat_map(|_| {
477//                 [{
478//                     let addr = thread_rng().gen_range(0..SP1Field::ORDER_U32);
479//                     let init_value = thread_rng().gen_range(0..u32::MAX);
480//                     let init_shard = thread_rng().gen_range(0..(1u32 << 16));
481//                     let init_timestamp = thread_rng().gen_range(0..(1u32 << 24));
482//                     let final_value = thread_rng().gen_range(0..u32::MAX);
483//                     let final_timestamp = thread_rng().gen_range(0..(1u32 << 24));
484//                     let final_shard = thread_rng().gen_range(0..(1u32 << 16));
485//                     MemoryLocalEvent {
486//                         addr,
487//                         initial_mem_access: MemoryRecord {
488//                             shard: init_shard,
489//                             timestamp: init_timestamp,
490//                             value: init_value,
491//                         },
492//                         final_mem_access: MemoryRecord {
493//                             shard: final_shard,
494//                             timestamp: final_timestamp,
495//                             value: final_value,
496//                         },
497//                     }
498//                 }]
499//             })
500//             .collect::<Vec<_>>();
501//         ExecutionRecord { cpu_local_memory_access, ..Default::default() }
502//     }
503
504//     #[cfg(feature = "sys")]
505//     #[test]
506//     fn test_generate_trace_ffi_eq_rust() {
507//         use slop_matrix::Matrix;
508
509//         let record = get_test_execution_record();
510//         let chip = MemoryLocalChip::new();
511//         let trace: RowMajorMatrix<SP1Field> =
512//             chip.generate_trace(&record, &mut ExecutionRecord::default());
513//         let trace_ffi = generate_trace_ffi(&record, trace.height());
514
515//         assert_eq!(trace_ffi, trace);
516//     }
517
518//     #[cfg(feature = "sys")]
519//     fn generate_trace_ffi(input: &ExecutionRecord, height: usize) -> RowMajorMatrix<SP1Field> {
520//         use std::borrow::BorrowMut;
521
522//         use rayon::iter::{IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator};
523
524//         use crate::{
525//             memory::{
526//                 MemoryLocalCols, NUM_LOCAL_MEMORY_ENTRIES_PER_ROW, NUM_MEMORY_LOCAL_INIT_COLS,
527//             },
528//             utils::zeroed_f_vec,
529//         };
530
531//         use sp1_primitives::SP1Field;
532// type F = SP1Field;
533//         // Generate the trace rows for each event.
534//         let events = input.get_local_mem_events().collect::<Vec<_>>();
535//         let nb_rows = events.len().div_ceil(4);
536//         let padded_nb_rows = height;
537//         let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_LOCAL_INIT_COLS);
538//         let chunk_size = std::cmp::max(nb_rows / num_cpus::get(), 0) + 1;
539
540//         let mut chunks = values[..nb_rows * NUM_MEMORY_LOCAL_INIT_COLS]
541//             .chunks_mut(chunk_size * NUM_MEMORY_LOCAL_INIT_COLS)
542//             .collect::<Vec<_>>();
543
544//         chunks.par_iter_mut().enumerate().for_each(|(i, rows)| {
545//             rows.chunks_mut(NUM_MEMORY_LOCAL_INIT_COLS).enumerate().for_each(|(j, row)| {
546//                 let idx = (i * chunk_size + j) * NUM_LOCAL_MEMORY_ENTRIES_PER_ROW;
547//                 let cols: &mut MemoryLocalCols<F> = row.borrow_mut();
548//                 for k in 0..NUM_LOCAL_MEMORY_ENTRIES_PER_ROW {
549//                     let cols = &mut cols.memory_local_entries[k];
550//                     if idx + k < events.len() {
551//                         unsafe {
552//                             crate::sys::memory_local_event_to_row_koalabear(events[idx + k],
553// cols);                         }
554//                     }
555//                 }
556//             });
557//         });
558
559//         // Convert the trace to a row major matrix.
560//         RowMajorMatrix::new(values, NUM_MEMORY_LOCAL_INIT_COLS)
561//     }
562// }