Skip to main content

sp1_core_machine/memory/global.rs

1use super::MemoryChipType;
2use crate::{
3    air::{SP1CoreAirBuilder, SP1Operation, WordAirBuilder},
4    operations::{
5        IsZeroOperation, IsZeroOperationInput, LtOperationUnsigned, LtOperationUnsignedInput,
6    },
7    utils::next_multiple_of_32,
8};
9use core::{
10    borrow::{Borrow, BorrowMut},
11    mem::{size_of, MaybeUninit},
12};
13use slop_air::{Air, AirBuilder, BaseAir};
14use slop_algebra::{AbstractField, PrimeField32};
15use slop_matrix::Matrix;
16use slop_maybe_rayon::prelude::{
17    IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator, ParallelSlice,
18    ParallelSliceMut,
19};
20use sp1_core_executor::{
21    events::{ByteRecord, GlobalInteractionEvent, MemoryInitializeFinalizeEvent},
22    ExecutionRecord, Program,
23};
24use sp1_derive::AlignedBorrow;
25use sp1_hypercube::{
26    air::{AirInteraction, InteractionScope, MachineAir},
27    InteractionKind, Word,
28};
29use sp1_primitives::consts::u64_to_u16_limbs;
30use std::iter::once;
31use struct_reflection::{StructReflection, StructReflectionHelper};
32
33/// A memory chip that can initialize or finalize values in memory.
34pub struct MemoryGlobalChip {
35    pub kind: MemoryChipType,
36}
37
38impl MemoryGlobalChip {
39    /// Creates a new memory chip with a certain type.
40    pub const fn new(kind: MemoryChipType) -> Self {
41        Self { kind }
42    }
43}
44
45impl<F> BaseAir<F> for MemoryGlobalChip {
46    fn width(&self) -> usize {
47        NUM_MEMORY_INIT_COLS
48    }
49}
50
51impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
52    type Record = ExecutionRecord;
53
54    type Program = Program;
55
56    fn name(&self) -> &'static str {
57        match self.kind {
58            MemoryChipType::Initialize => "MemoryGlobalInit",
59            MemoryChipType::Finalize => "MemoryGlobalFinalize",
60        }
61    }
62
63    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
64        let mut memory_events = match self.kind {
65            MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
66            MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
67        };
68
69        let is_receive = match self.kind {
70            MemoryChipType::Initialize => false,
71            MemoryChipType::Finalize => true,
72        };
73
74        match self.kind {
75            MemoryChipType::Initialize => {
76                output.public_values.global_init_count += memory_events.len() as u32;
77            }
78            MemoryChipType::Finalize => {
79                output.public_values.global_finalize_count += memory_events.len() as u32;
80            }
81        };
82
83        let previous_addr = match self.kind {
84            MemoryChipType::Initialize => input.public_values.previous_init_addr,
85            MemoryChipType::Finalize => input.public_values.previous_finalize_addr,
86        };
87
88        memory_events.sort_by_key(|event| event.addr);
89
90        let chunk_size = std::cmp::max(memory_events.len() / num_cpus::get(), 1);
91        let indices = (0..memory_events.len()).collect::<Vec<_>>();
92        let blu_batches = indices
93            .par_chunks(chunk_size)
94            .map(|chunk| {
95                let mut blu = Vec::new();
96                let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
97                let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();
98                chunk.iter().for_each(|&i| {
99                    let addr = memory_events[i].addr;
100                    let value = memory_events[i].value;
101                    let prev_addr = if i == 0 { previous_addr } else { memory_events[i - 1].addr };
102                    blu.add_u16_range_checks(&u64_to_u16_limbs(value));
103                    blu.add_u16_range_checks(&u64_to_u16_limbs(prev_addr)[0..3]);
104                    blu.add_u16_range_checks(&u64_to_u16_limbs(addr)[0..3]);
105                    let value_lower = (value >> 32 & 0xFF) as u8;
106                    let value_upper = (value >> 40 & 0xFF) as u8;
107                    blu.add_u8_range_check(value_lower, value_upper);
108                    if i != 0 || prev_addr != 0 {
109                        cols.lt_cols.populate_unsigned(&mut blu, 1, prev_addr, addr);
110                    }
111                });
112                blu
113            })
114            .collect::<Vec<_>>();
115        output.add_byte_lookup_events(blu_batches.into_iter().flatten().collect());
116
117        let events = memory_events.into_iter().map(|event| {
118            let interaction_clk_high = if is_receive { (event.timestamp >> 24) as u32 } else { 0 };
119            let interaction_clk_low =
120                if is_receive { (event.timestamp & 0xFFFFFF) as u32 } else { 0 };
121            let limb_1 =
122                (event.value & 0xFFFF) as u32 + (1 << 16) * (event.value >> 32 & 0xFF) as u32;
123            let limb_2 =
124                (event.value >> 16 & 0xFFFF) as u32 + (1 << 16) * (event.value >> 40 & 0xFF) as u32;
125
126            GlobalInteractionEvent {
127                message: [
128                    interaction_clk_high,
129                    interaction_clk_low,
130                    (event.addr & 0xFFFF) as u32,
131                    ((event.addr >> 16) & 0xFFFF) as u32,
132                    ((event.addr >> 32) & 0xFFFF) as u32,
133                    limb_1,
134                    limb_2,
135                    ((event.value >> 48) & 0xFFFF) as u32,
136                ],
137                is_receive,
138                kind: InteractionKind::Memory as u8,
139            }
140        });
141        output.global_interaction_events.extend(events);
142    }
143
144    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
145        let events = match self.kind {
146            MemoryChipType::Initialize => &input.global_memory_initialize_events,
147            MemoryChipType::Finalize => &input.global_memory_finalize_events,
148        };
149        let nb_rows = events.len();
150        let size_log2 = input.fixed_log2_rows::<F, Self>(self);
151        let padded_nb_rows = next_multiple_of_32(nb_rows, size_log2);
152        Some(padded_nb_rows)
153    }
154
155    fn generate_trace_into(
156        &self,
157        input: &ExecutionRecord,
158        _output: &mut ExecutionRecord,
159        buffer: &mut [MaybeUninit<F>],
160    ) {
161        let mut memory_events = match self.kind {
162            MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
163            MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
164        };
165
166        let previous_addr = match self.kind {
167            MemoryChipType::Initialize => input.public_values.previous_init_addr,
168            MemoryChipType::Finalize => input.public_values.previous_finalize_addr,
169        };
170
171        memory_events.sort_by_key(|event| event.addr);
172
173        let padded_nb_rows = <MemoryGlobalChip as MachineAir<F>>::num_rows(self, input).unwrap();
174        let num_event_rows = memory_events.len();
175
176        unsafe {
177            let padding_start = num_event_rows * NUM_MEMORY_INIT_COLS;
178            let padding_size = (padded_nb_rows - num_event_rows) * NUM_MEMORY_INIT_COLS;
179            if padding_size > 0 {
180                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
181            }
182        }
183
184        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
185        let values = unsafe {
186            core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_MEMORY_INIT_COLS)
187        };
188
189        values.par_chunks_exact_mut(NUM_MEMORY_INIT_COLS).zip(memory_events.par_iter()).for_each(
190            |(row, event)| {
191                let cols: &mut MemoryInitCols<F> = row.borrow_mut();
192                let MemoryInitializeFinalizeEvent { addr, value, timestamp } = event.to_owned();
193
194                cols.addr[0] = F::from_canonical_u16((addr & 0xFFFF) as u16);
195                cols.addr[1] = F::from_canonical_u16(((addr >> 16) & 0xFFFF) as u16);
196                cols.addr[2] = F::from_canonical_u16(((addr >> 32) & 0xFFFF) as u16);
197                cols.clk_high = F::from_canonical_u32((timestamp >> 24) as u32);
198                cols.clk_low = F::from_canonical_u32((timestamp & 0xFFFFFF) as u32);
199                cols.value = Word::from(value);
200                cols.is_real = F::one();
201                cols.value_lower = F::from_canonical_u32((value >> 32 & 0xFF) as u32);
202                cols.value_upper = F::from_canonical_u32((value >> 40 & 0xFF) as u32);
203            },
204        );
205
206        let mut blu = vec![];
207        for i in 0..memory_events.len() {
208            let row_start = i * NUM_MEMORY_INIT_COLS;
209            let row = &mut values[row_start..row_start + NUM_MEMORY_INIT_COLS];
210            let cols: &mut MemoryInitCols<F> = row.borrow_mut();
211
212            let addr = memory_events[i].addr;
213            let prev_addr = if i == 0 { previous_addr } else { memory_events[i - 1].addr };
214
215            if prev_addr == 0 && i != 0 {
216                cols.prev_valid = F::zero();
217            } else {
218                cols.prev_valid = F::one();
219            }
220            cols.index = F::from_canonical_u32(i as u32);
221            cols.prev_addr[0] = F::from_canonical_u16((prev_addr & 0xFFFF) as u16);
222            cols.prev_addr[1] = F::from_canonical_u16(((prev_addr >> 16) & 0xFFFF) as u16);
223            cols.prev_addr[2] = F::from_canonical_u16(((prev_addr >> 32) & 0xFFFF) as u16);
224            cols.is_prev_addr_zero.populate_from_field_element(
225                cols.prev_addr[0] + cols.prev_addr[1] + cols.prev_addr[2],
226            );
227            cols.is_index_zero.populate(i as u64);
228            if prev_addr != 0 || i != 0 {
229                cols.is_comp = F::one();
230                cols.lt_cols.populate_unsigned(&mut blu, 1, prev_addr, addr);
231            } else {
232                cols.is_comp = F::zero();
233                cols.lt_cols = LtOperationUnsigned::<F>::default();
234            }
235        }
236    }
237
238    fn included(&self, shard: &Self::Record) -> bool {
239        if let Some(shape) = shard.shape.as_ref() {
240            shape.included::<F, _>(self)
241        } else {
242            match self.kind {
243                MemoryChipType::Initialize => !shard.global_memory_initialize_events.is_empty(),
244                MemoryChipType::Finalize => !shard.global_memory_finalize_events.is_empty(),
245            }
246        }
247    }
248
249    fn column_names(&self) -> Vec<String> {
250        MemoryInitCols::<F>::struct_reflection().unwrap()
251    }
252}
253
/// Column layout of one `MemoryGlobalChip` trace row.
///
/// Rows are read and written by borrowing a row slice as this struct (`#[repr(C)]` +
/// `AlignedBorrow`), so the field order *is* the column order. Do not reorder fields.
#[derive(AlignedBorrow, Clone, Copy, StructReflection)]
#[repr(C)]
pub struct MemoryInitCols<T: Copy> {
    /// Upper bits of the event timestamp (`timestamp >> 24`). Only the finalize chip puts the
    /// clock into its global interaction.
    pub clk_high: T,

    /// Low 24 bits of the event timestamp (`timestamp & 0xFFFFFF`).
    pub clk_low: T,

    /// Position of this event in the chip's address-sorted event list.
    pub index: T,

    /// Three 16-bit limbs (least significant first) of the previous event's address. For the
    /// first row this is the `previous_init_addr` / `previous_finalize_addr` public value.
    pub prev_addr: [T; 3],

    /// Three 16-bit limbs (least significant first) of this event's address.
    pub addr: [T; 3],

    /// Witness for the comparison `prev_addr < addr`, enforced when `is_comp = 1`.
    pub lt_cols: LtOperationUnsigned<T>,

    /// The memory value as a `Word` (four 16-bit limbs).
    pub value: Word<T>,

    /// Low byte of limb 2 of `value` (bits 32..40 of the value).
    pub value_lower: T,

    /// High byte of limb 2 of `value` (bits 40..48 of the value).
    pub value_upper: T,

    /// 1 for rows that hold an event, 0 for zero-filled padding rows.
    pub is_real: T,

    /// Whether or not we are making the assertion `prev_addr < addr`.
    pub is_comp: T,

    /// Validity flag received together with `prev_addr`. It must equal the `is_comp` sent by the
    /// previous row. The trace sets it to 0 exactly when `index != 0` and `prev_addr == 0`, i.e.
    /// after a row at index 0 that handled address 0; otherwise it is 1.
    pub prev_valid: T,

    /// `IsZero` witness for the sum of the `prev_addr` limbs (i.e. whether `prev_addr == 0`).
    pub is_prev_addr_zero: IsZeroOperation<T>,

    /// `IsZero` witness for `index`.
    pub is_index_zero: IsZeroOperation<T>,
}

/// Number of columns in one [`MemoryInitCols`] row.
pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();
302
/// Per-row constraints of the global memory init/finalize chip.
///
/// For every real row (`is_real = 1`) the AIR:
/// - range-checks the value (four u16 limbs, plus both bytes of limb 2) and the three u16 limbs
///   of `prev_addr` and `addr`;
/// - receives `(index, prev_addr, prev_valid)` and sends `(index + 1, addr, is_comp)` on the
///   chip's local control interaction, linking each row to the next one;
/// - sends the memory entry to the `Global` interaction. Initialization rows use a zero clock
///   and the send flag; finalization rows use `clk_high`/`clk_low` and the receive flag;
/// - enforces `prev_addr < addr` unless both `prev_addr` and `index` are zero. In that case
///   `addr` and `value` must both be zero.
impl<AB> Air<AB> for MemoryGlobalChip
where
    AB: SP1CoreAirBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &MemoryInitCols<AB::Var> = (*local).borrow();

        // Constrain that `local.is_real` is boolean.
        builder.assert_bool(local.is_real);
        // Constrain that every limb of the value is a valid u16.
        builder.slice_range_check_u16(&local.value.0, local.is_real);
        // Constrain that the three limbs of the previous address are valid u16 values.
        builder.slice_range_check_u16(&local.prev_addr, local.is_real);
        // Constrain that the three limbs of the address are valid u16 values.
        builder.slice_range_check_u16(&local.addr, local.is_real);

        // Assert that value_lower and value_upper are the lower and upper bytes of the third
        // limb. This holds on every row; on zero-filled padding rows both sides are zero.
        builder.assert_eq(
            local.value.0[2],
            local.value_lower + local.value_upper * AB::F::from_canonical_u32(1 << 8),
        );
        builder.slice_range_check_u8(&[local.value_lower, local.value_upper], local.is_real);

        let interaction_kind = match self.kind {
            MemoryChipType::Initialize => InteractionKind::MemoryGlobalInitControl,
            MemoryChipType::Finalize => InteractionKind::MemoryGlobalFinalizeControl,
        };

        // Receive the previous index, address, and validity state.
        builder.receive(
            AirInteraction::new(
                vec![local.index]
                    .into_iter()
                    .chain(local.prev_addr)
                    .chain(once(local.prev_valid))
                    .map(Into::into)
                    .collect(),
                local.is_real.into(),
                interaction_kind,
            ),
            InteractionScope::Local,
        );

        // Send the next index, address, and validity state. `is_comp` becomes the next row's
        // `prev_valid`.
        builder.send(
            AirInteraction::new(
                vec![local.index + AB::Expr::one()]
                    .into_iter()
                    .chain(local.addr.map(Into::into))
                    .chain(once(local.is_comp.into()))
                    .collect(),
                local.is_real.into(),
                interaction_kind,
            ),
            InteractionScope::Local,
        );

        // The global message packs value bytes 4 and 5 (`value_lower` / `value_upper`) above
        // bit 16 of value limbs 0 and 1, and omits limb 2 itself.
        if self.kind == MemoryChipType::Initialize {
            // Send the "send interaction" to the global table.
            builder.send(
                AirInteraction::new(
                    vec![
                        AB::Expr::zero(),
                        AB::Expr::zero(),
                        local.addr[0].into(),
                        local.addr[1].into(),
                        local.addr[2].into(),
                        local.value.0[0] + local.value_lower * AB::F::from_canonical_u32(1 << 16),
                        local.value.0[1] + local.value_upper * AB::F::from_canonical_u32(1 << 16),
                        local.value.0[3].into(),
                        AB::Expr::one(),
                        AB::Expr::zero(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        } else {
            // Send the "receive interaction" to the global table.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.clk_high.into(),
                        local.clk_low.into(),
                        local.addr[0].into(),
                        local.addr[1].into(),
                        local.addr[2].into(),
                        local.value.0[0] + local.value_lower * AB::F::from_canonical_u32(1 << 16),
                        local.value.0[1] + local.value_upper * AB::F::from_canonical_u32(1 << 16),
                        local.value.0[3].into(),
                        AB::Expr::zero(),
                        AB::Expr::one(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        }

        // Assert that `prev_addr < addr` when `prev_addr != 0` or `index != 0`.
        // First, compute whether `prev_addr == 0` and whether `index == 0`.
        // SAFETY: Since the `prev_addr` limbs are range-checked u16 values, their sum is zero
        // exactly when all three limbs are zero; the sum cannot wrap around the field.
        IsZeroOperation::<AB::F>::eval(
            builder,
            IsZeroOperationInput::new(
                local.prev_addr[0] + local.prev_addr[1] + local.prev_addr[2],
                local.is_prev_addr_zero,
                local.is_real.into(),
            ),
        );
        IsZeroOperation::<AB::F>::eval(
            builder,
            IsZeroOperationInput::new(
                local.index.into(),
                local.is_index_zero,
                local.is_real.into(),
            ),
        );

        // Comparison will be done unless both `prev_addr == 0` and `index == 0`.
        // If `is_real = 0`, then `is_comp` will be zero.
        // If `is_real = 1`, then `is_comp` will be zero when `prev_addr == 0` and `index == 0`.
        // If `is_real = 1`, then `is_comp` will be one when `prev_addr != 0` or `index != 0`.
        builder.assert_eq(
            local.is_comp,
            local.is_real
                * (AB::Expr::one() - local.is_prev_addr_zero.result * local.is_index_zero.result),
        );
        builder.assert_bool(local.is_comp);

        // If `is_comp = 1`, then `prev_addr < addr` must hold. Both addresses are compared as
        // words whose top limb is zero.
        <LtOperationUnsigned<AB::F> as SP1Operation<AB>>::eval(
            builder,
            LtOperationUnsignedInput::<AB>::new(
                Word([
                    local.prev_addr[0].into(),
                    local.prev_addr[1].into(),
                    local.prev_addr[2].into(),
                    AB::Expr::zero(),
                ]),
                Word([
                    local.addr[0].into(),
                    local.addr[1].into(),
                    local.addr[2].into(),
                    AB::Expr::zero(),
                ]),
                local.lt_cols,
                local.is_comp.into(),
            ),
        );
        builder.when(local.is_comp).assert_one(local.lt_cols.u16_compare_operation.bit);

        // If `prev_addr == 0` and `index == 0` on a real row, then `addr == 0` and `value` must be
        // zero.
        // SAFETY: Since the `local.addr` limbs are range-checked u16 values, constraining their
        // sum to zero constrains `addr == 0`, as no overflow is possible.
        // This forces the initialization of address 0 with value 0.
        // Original authors' note: this relates to register %x0, which should always be 0.
        // See 2.6 Load and Store Instruction on P.18 of the RISC-V spec.
        let is_not_comp = local.is_real - local.is_comp;
        builder
            .when(is_not_comp.clone())
            .assert_zero(local.addr[0] + local.addr[1] + local.addr[2]);
        builder.when(is_not_comp.clone()).assert_word_zero(local.value);
    }
}
475
476// #[cfg(test)]
477// mod tests {
478//     #![allow(clippy::print_stdout)]
479
480//     use super::*;
481//     use crate::programs::tests::*;
482//     use crate::{
483//         riscv::RiscvAir, syscall::precompiles::sha256::extend_tests::sha_extend_program,
484//         utils::setup_logger,
485//     };
486//     use sp1_primitives::SP1Field;
487//     use sp1_core_executor::{Executor, Trace};
488//     use sp1_hypercube::InteractionKind;
489//     use sp1_hypercube::{
490//         koala_bear_poseidon2::SP1InnerPcs, debug_interactions_with_all_chips,
491// SP1CoreOpts,         StarkMachine,
492//     };
493
494//     #[test]
495//     fn test_memory_generate_trace() {
496//         let program = simple_program();
497//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
498//         runtime.run::<Trace>().unwrap();
499//         let shard = runtime.record.clone();
500
501//         let chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Initialize);
502
503//         let trace: RowMajorMatrix<SP1Field> =
504//             chip.generate_trace(&shard, &mut ExecutionRecord::default());
505//         println!("{:?}", trace.values);
506
507//         let chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Finalize);
508//         let trace: RowMajorMatrix<SP1Field> =
509//             chip.generate_trace(&shard, &mut ExecutionRecord::default());
510//         println!("{:?}", trace.values);
511
512//         for mem_event in shard.global_memory_finalize_events {
513//             println!("{:?}", mem_event);
514//         }
515//     }
516
517//     #[test]
518//     fn test_memory_lookup_interactions() {
519//         setup_logger();
520//         let program = sha_extend_program();
521//         let program_clone = program.clone();
522//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
523//         runtime.run::<Trace>().unwrap();
524//         let machine: StarkMachine<SP1InnerPcs, RiscvAir<SP1Field>> =
525//             RiscvAir::machine(SP1InnerPcs::new());
526//         let (pkey, _) = machine.setup(&program_clone);
527//         let opts = SP1CoreOpts::default();
528//         machine.generate_dependencies(
529//             &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
530//             &opts,
531//             None,
532//         );
533
534//         let shards = runtime.records;
535//         for shard in shards.clone() {
536//             debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
537//                 &machine,
538//                 &pkey,
539//                 &[*shard],
540//                 vec![InteractionKind::Memory],
541//                 InteractionScope::Local,
542//             );
543//         }
544//         debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
545//             &machine,
546//             &pkey,
547//             &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
548//             vec![InteractionKind::Memory],
549//             InteractionScope::Global,
550//         );
551//     }
552
553//     #[test]
554//     fn test_byte_lookup_interactions() {
555//         setup_logger();
556//         let program = sha_extend_program();
557//         let program_clone = program.clone();
558//         let mut runtime = Executor::new(program, SP1CoreOpts::default());
559//         runtime.run::<Trace>().unwrap();
560//         let machine = RiscvAir::machine(SP1InnerPcs::new());
561//         let (pkey, _) = machine.setup(&program_clone);
562//         let opts = SP1CoreOpts::default();
563//         machine.generate_dependencies(
564//             &mut runtime.records.clone().into_iter().map(|r| *r).collect::<Vec<_>>(),
565//             &opts,
566//             None,
567//         );
568
569//         let shards = runtime.records;
570//         debug_interactions_with_all_chips::<SP1InnerPcs, RiscvAir<SP1Field>>(
571//             &machine,
572//             &pkey,
573//             &shards.into_iter().map(|r| *r).collect::<Vec<_>>(),
574//             vec![InteractionKind::Byte],
575//             InteractionScope::Global,
576//         );
577//     }
578// }