sp1_core_machine/memory/
global.rs

1use super::MemoryChipType;
2use crate::{
3    operations::{AssertLtColsBits, BabyBearBitDecomposition, IsZeroOperation},
4    utils::next_power_of_two,
5};
6use core::{
7    borrow::{Borrow, BorrowMut},
8    mem::size_of,
9};
10
11use p3_air::{Air, AirBuilder, BaseAir};
12use p3_field::{AbstractField, PrimeField32};
13use p3_matrix::{dense::RowMajorMatrix, Matrix};
14use p3_maybe_rayon::prelude::{IntoParallelRefIterator, ParallelIterator};
15use sp1_core_executor::{
16    events::{GlobalInteractionEvent, MemoryInitializeFinalizeEvent},
17    ExecutionRecord, Program,
18};
19use sp1_derive::AlignedBorrow;
20use sp1_stark::{
21    air::{
22        AirInteraction, BaseAirBuilder, InteractionScope, MachineAir, PublicValues, SP1AirBuilder,
23        SP1_PROOF_NUM_PV_ELTS,
24    },
25    InteractionKind, Word,
26};
27use std::array;
28
/// A memory chip that can initialize or finalize values in memory.
pub struct MemoryGlobalChip {
    /// Whether this chip instance initializes or finalizes memory cells.
    pub kind: MemoryChipType,
}
33
34impl MemoryGlobalChip {
35    /// Creates a new memory chip with a certain type.
36    pub const fn new(kind: MemoryChipType) -> Self {
37        Self { kind }
38    }
39}
40
impl<F> BaseAir<F> for MemoryGlobalChip {
    /// The trace width equals the number of columns in `MemoryInitCols`.
    fn width(&self) -> usize {
        NUM_MEMORY_INIT_COLS
    }
}
46
impl<F: PrimeField32> MachineAir<F> for MemoryGlobalChip {
    type Record = ExecutionRecord;

    type Program = Program;

    fn name(&self) -> String {
        match self.kind {
            MemoryChipType::Initialize => "MemoryGlobalInit".to_string(),
            MemoryChipType::Finalize => "MemoryGlobalFinalize".to_string(),
        }
    }

    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
        // Pick the event set matching this chip's kind.
        let mut memory_events = match self.kind {
            MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
            MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
        };

        // Initialize events are "sends" into the global table; finalize events are "receives".
        let is_receive = match self.kind {
            MemoryChipType::Initialize => false,
            MemoryChipType::Finalize => true,
        };

        // Sort by address so the emitted events line up with the trace row order.
        memory_events.sort_by_key(|event| event.addr);

        let events = memory_events.into_iter().map(|event| {
            // Initialization messages carry a zero shard/clk, matching the AIR's send
            // interaction below in `eval`.
            let interaction_shard = if is_receive { event.shard } else { 0 };
            let interaction_clk = if is_receive { event.timestamp } else { 0 };
            GlobalInteractionEvent {
                // Message layout: [shard, clk, addr, byte0, byte1, byte2, byte3]
                // (value split into little-endian bytes).
                message: [
                    interaction_shard,
                    interaction_clk,
                    event.addr,
                    (event.value & 255) as u32,
                    ((event.value >> 8) & 255) as u32,
                    ((event.value >> 16) & 255) as u32,
                    ((event.value >> 24) & 255) as u32,
                ],
                is_receive,
                kind: InteractionKind::Memory as u8,
            }
        });
        output.global_interaction_events.extend(events);
    }

    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let events = match self.kind {
            MemoryChipType::Initialize => &input.global_memory_initialize_events,
            MemoryChipType::Finalize => &input.global_memory_finalize_events,
        };
        let nb_rows = events.len();
        // Respect a fixed proof shape if one is configured; otherwise pad to the next
        // power of two.
        let size_log2 = input.fixed_log2_rows::<F, Self>(self);
        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
        Some(padded_nb_rows)
    }

    fn generate_trace(
        &self,
        input: &ExecutionRecord,
        _output: &mut ExecutionRecord,
    ) -> RowMajorMatrix<F> {
        let mut memory_events = match self.kind {
            MemoryChipType::Initialize => input.global_memory_initialize_events.clone(),
            MemoryChipType::Finalize => input.global_memory_finalize_events.clone(),
        };

        // Bit decomposition of the last address initialized/finalized by previous shards,
        // taken from the public values.
        let previous_addr_bits = match self.kind {
            MemoryChipType::Initialize => input.public_values.previous_init_addr_bits,
            MemoryChipType::Finalize => input.public_values.previous_finalize_addr_bits,
        };

        // Rows must be sorted by address for the strict-increase constraints to hold.
        memory_events.sort_by_key(|event| event.addr);
        // Populate the per-event columns in parallel; adjacent-row columns are filled in a
        // sequential pass below.
        let mut rows: Vec<[F; NUM_MEMORY_INIT_COLS]> = memory_events
            .par_iter()
            .map(|event| {
                let MemoryInitializeFinalizeEvent { addr, value, shard, timestamp } =
                    event.to_owned();

                let mut row = [F::zero(); NUM_MEMORY_INIT_COLS];
                let cols: &mut MemoryInitCols<F> = row.as_mut_slice().borrow_mut();
                cols.addr = F::from_canonical_u32(addr);
                cols.addr_bits.populate(addr);
                cols.shard = F::from_canonical_u32(shard);
                cols.timestamp = F::from_canonical_u32(timestamp);
                // The value is stored as 32 individual bits.
                cols.value = array::from_fn(|i| F::from_canonical_u32((value >> i) & 1));
                cols.is_real = F::one();

                row
            })
            .collect::<Vec<_>>();

        for i in 0..memory_events.len() {
            let addr = memory_events[i].addr;
            let cols: &mut MemoryInitCols<F> = rows[i].as_mut_slice().borrow_mut();
            if i == 0 {
                // First row: compare against the previous shard's last address, unless
                // that address is zero (reassemble it from its bit decomposition).
                let prev_addr = previous_addr_bits
                    .iter()
                    .enumerate()
                    .map(|(j, bit)| bit * (1 << j))
                    .sum::<u32>();
                cols.is_prev_addr_zero.populate(prev_addr);
                cols.is_first_comp = F::from_bool(prev_addr != 0);
                if prev_addr != 0 {
                    debug_assert!(prev_addr < addr, "prev_addr {prev_addr} < addr {addr}");
                    let addr_bits: [_; 32] = array::from_fn(|i| (addr >> i) & 1);
                    cols.lt_cols.populate(&previous_addr_bits, &addr_bits);
                }
            }
            if i != 0 {
                // Non-first rows: witness the strict increase over the previous row's
                // address.
                cols.is_next_comp = F::one();
                let previous_addr = memory_events[i - 1].addr;
                assert_ne!(previous_addr, addr);

                let addr_bits: [_; 32] = array::from_fn(|i| (addr >> i) & 1);
                let prev_addr_bits: [_; 32] = array::from_fn(|i| (previous_addr >> i) & 1);
                cols.lt_cols.populate(&prev_addr_bits, &addr_bits);
            }

            if i == memory_events.len() - 1 {
                // Mark the last real row so it can be tied to the public
                // `last_*_addr_bits` values in `eval`.
                cols.is_last_addr = F::one();
            }
        }

        // Pad the trace to a power of two depending on the proof shape in `input`.
        rows.resize(
            <MemoryGlobalChip as MachineAir<F>>::num_rows(self, input).unwrap(),
            [F::zero(); NUM_MEMORY_INIT_COLS],
        );

        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEMORY_INIT_COLS)
    }

    fn included(&self, shard: &Self::Record) -> bool {
        // With a fixed shape, inclusion is dictated by the shape; otherwise include the
        // chip only when it has events to prove.
        if let Some(shape) = shard.shape.as_ref() {
            shape.included::<F, _>(self)
        } else {
            match self.kind {
                MemoryChipType::Initialize => !shard.global_memory_initialize_events.is_empty(),
                MemoryChipType::Finalize => !shard.global_memory_finalize_events.is_empty(),
            }
        }
    }

    fn commit_scope(&self) -> InteractionScope {
        InteractionScope::Local
    }
}
194
/// Columns of the global memory initialize/finalize trace.
///
/// NOTE: field order is load-bearing — the struct is `#[repr(C)]` and borrowed directly
/// from a row slice via `AlignedBorrow`.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct MemoryInitCols<T: Copy> {
    /// The shard number of the memory access.
    pub shard: T,

    /// The timestamp of the memory access.
    pub timestamp: T,

    /// The address of the memory access.
    pub addr: T,

    /// Comparison assertions for address to be strictly increasing.
    pub lt_cols: AssertLtColsBits<T, 32>,

    /// A bit decomposition of `addr`.
    pub addr_bits: BabyBearBitDecomposition<T>,

    /// The value of the memory access, as 32 individual bits.
    pub value: [T; 32],

    /// Whether the memory access is a real access.
    pub is_real: T,

    /// Whether or not we are making the assertion `addr < addr_next`.
    pub is_next_comp: T,

    /// A witness to assert whether or not the previous address is zero.
    pub is_prev_addr_zero: IsZeroOperation<T>,

    /// Auxiliary column, equal to `(1 - is_prev_addr_zero.result) * is_first_row`.
    pub is_first_comp: T,

    /// A flag to indicate the last non-padded address. An auxiliary column needed for degree 3.
    pub is_last_addr: T,
}

/// Width of the memory init/finalize trace in field elements.
pub(crate) const NUM_MEMORY_INIT_COLS: usize = size_of::<MemoryInitCols<u8>>();
233
impl<AB> Air<AB> for MemoryGlobalChip
where
    AB: SP1AirBuilder,
{
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        // Borrow the current and next rows as typed column structs.
        let local = main.row_slice(0);
        let local: &MemoryInitCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next: &MemoryInitCols<AB::Var> = (*next).borrow();

        // Constrain that `local.is_real` is boolean.
        builder.assert_bool(local.is_real);
        for i in 0..32 {
            builder.assert_bool(local.value[i]);
        }

        // Recompose the 32 value bits into four little-endian bytes for the global
        // interaction message.
        let mut byte1 = AB::Expr::zero();
        let mut byte2 = AB::Expr::zero();
        let mut byte3 = AB::Expr::zero();
        let mut byte4 = AB::Expr::zero();
        for i in 0..8 {
            byte1 = byte1.clone() + local.value[i].into() * AB::F::from_canonical_u8(1 << i);
            byte2 = byte2.clone() + local.value[i + 8].into() * AB::F::from_canonical_u8(1 << i);
            byte3 = byte3.clone() + local.value[i + 16].into() * AB::F::from_canonical_u8(1 << i);
            byte4 = byte4.clone() + local.value[i + 24].into() * AB::F::from_canonical_u8(1 << i);
        }
        let value = [byte1, byte2, byte3, byte4];

        if self.kind == MemoryChipType::Initialize {
            // Send the "send interaction" to the global table.
            // Message layout matches `generate_dependencies`: initialization uses a zero
            // shard and clk.
            builder.send(
                AirInteraction::new(
                    vec![
                        AB::Expr::zero(),
                        AB::Expr::zero(),
                        local.addr.into(),
                        value[0].clone(),
                        value[1].clone(),
                        value[2].clone(),
                        value[3].clone(),
                        AB::Expr::one(),
                        AB::Expr::zero(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        } else {
            // Send the "receive interaction" to the global table, carrying the real shard
            // and timestamp.
            builder.send(
                AirInteraction::new(
                    vec![
                        local.shard.into(),
                        local.timestamp.into(),
                        local.addr.into(),
                        value[0].clone(),
                        value[1].clone(),
                        value[2].clone(),
                        value[3].clone(),
                        AB::Expr::zero(),
                        AB::Expr::one(),
                        AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
                    ],
                    local.is_real.into(),
                    InteractionKind::Global,
                ),
                InteractionScope::Local,
            );
        }

        // Canonically decompose the address into bits so we can do comparisons.
        BabyBearBitDecomposition::<AB::F>::range_check(
            builder,
            local.addr,
            local.addr_bits,
            local.is_real.into(),
        );

        // Assertion for increasing address. We need to make two types of less-than assertions,
        // first we need to assert that the addr < addr' when the next row is real. Then we need to
        // make assertions with regards to public values.
        //
        // If the chip is a `MemoryInit`:
        // - In the first row, we need to assert that previous_init_addr < addr.
        // - In the last real row, we need to assert that addr = last_init_addr.
        //
        // If the chip is a `MemoryFinalize`:
        // - In the first row, we need to assert that previous_finalize_addr < addr.
        // - In the last real row, we need to assert that addr = last_finalize_addr.

        // Assert that addr < addr' when the next row is real.
        builder.when_transition().assert_eq(next.is_next_comp, next.is_real);
        next.lt_cols.eval(builder, &local.addr_bits.bits, &next.addr_bits.bits, next.is_next_comp);

        // Assert that the real rows are all padded to the top.
        builder.when_transition().when_not(local.is_real).assert_zero(next.is_real);

        // Make assertions for the initial comparison.

        // We want to constrain that the `addr` in the first row is larger than the previous
        // initialized/finalized address, unless the previous address is zero. Since the previous
        // address is either zero or constrained by a different shard, we know it's an element of
        // the field, so we can get an element from the bit decomposition with no concern for
        // overflow.

        let local_addr_bits = local.addr_bits.bits;

        let public_values_array: [AB::Expr; SP1_PROOF_NUM_PV_ELTS] =
            array::from_fn(|i| builder.public_values()[i].into());
        let public_values: &PublicValues<Word<AB::Expr>, AB::Expr> =
            public_values_array.as_slice().borrow();

        let prev_addr_bits = match self.kind {
            MemoryChipType::Initialize => &public_values.previous_init_addr_bits,
            MemoryChipType::Finalize => &public_values.previous_finalize_addr_bits,
        };

        // Since the previous address is either zero or constrained by a different shard, we know
        // it's an element of the field, so we can get an element from the bit decomposition with
        // no concern for overflow.
        let prev_addr = prev_addr_bits
            .iter()
            .enumerate()
            .map(|(i, bit)| bit.clone() * AB::F::from_wrapped_u32(1 << i))
            .sum::<AB::Expr>();

        // Constrain the is_prev_addr_zero operation only in the first row.
        let is_first_row = builder.is_first_row();
        IsZeroOperation::<AB::F>::eval(builder, prev_addr, local.is_prev_addr_zero, is_first_row);

        // Constrain the is_first_comp column.
        builder.assert_bool(local.is_first_comp);
        builder
            .when_first_row()
            .assert_eq(local.is_first_comp, AB::Expr::one() - local.is_prev_addr_zero.result);

        // Ensure at least one real row.
        builder.when_first_row().assert_one(local.is_real);

        // Constrain the inequality assertion in the first row.
        local.lt_cols.eval(builder, prev_addr_bits, &local_addr_bits, local.is_first_comp);

        // Ensure that there are no duplicate initializations by assuring there is exactly one
        // initialization event of the zero address. This is done by assuring that when the previous
        // address is zero, then the first row address is also zero, and that the second row is also
        // real, and the less than comparison is being made.
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_zero(local.addr);
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_one(next.is_real);
        // Ensure that in the address zero case the comparison is being made so that there is an
        // address bigger than zero being committed to.
        builder.when_first_row().when(local.is_prev_addr_zero.result).assert_one(next.is_next_comp);

        // Constraints related to register %x0.

        // Register %x0 should always be 0. See 2.6 Load and Store Instruction on
        // P.18 of the RISC-V spec.  To ensure that, we will constrain that the value is zero
        // whenever the `is_first_comp` flag is set to zero as well. This guarantees that the
        // presence of this flag asserts the initialization/finalization of %x0 to zero.
        //
        // **Remark**: it is up to the verifier to ensure that this flag is set to zero exactly
        // once, this can be constrained by the public values setting `previous_init_addr_bits` or
        // `previous_finalize_addr_bits` to zero.
        for i in 0..32 {
            builder.when_first_row().when_not(local.is_first_comp).assert_zero(local.value[i]);
        }

        // Make assertions for the final value. We need to connect the final valid address to the
        // corresponding `last_addr` value.
        let last_addr_bits = match self.kind {
            MemoryChipType::Initialize => &public_values.last_init_addr_bits,
            MemoryChipType::Finalize => &public_values.last_finalize_addr_bits,
        };
        // The last address is either:
        // - It's the last row and `is_real` is set to one.
        // - The flag `is_real` is set to one and the next `is_real` is set to zero.

        // Constrain the `is_last_addr` flag.
        builder
            .when_transition()
            .assert_eq(local.is_last_addr, local.is_real * (AB::Expr::one() - next.is_real));

        // Constrain the last address bits to be equal to the corresponding `last_addr_bits` value.
        for (local_bit, pub_bit) in local.addr_bits.bits.iter().zip(last_addr_bits.iter()) {
            builder.when_last_row().when(local.is_real).assert_eq(*local_bit, pub_bit.clone());
            builder
                .when_transition()
                .when(local.is_last_addr)
                .assert_eq(*local_bit, pub_bit.clone());
        }
    }
}
428
#[cfg(test)]
mod tests {
    #![allow(clippy::print_stdout)]

    use super::*;
    use crate::{
        programs::tests::*, riscv::RiscvAir,
        syscall::precompiles::sha256::extend_tests::sha_extend_program, utils::setup_logger,
    };
    use p3_baby_bear::BabyBear;
    use sp1_core_executor::Executor;
    use sp1_stark::{
        baby_bear_poseidon2::BabyBearPoseidon2, debug_interactions_with_all_chips, InteractionKind,
        SP1CoreOpts, StarkMachine,
    };

    /// Smoke test: generate and dump the init and finalize traces for a simple program.
    #[test]
    fn test_memory_generate_trace() {
        let program = simple_program();
        let mut executor = Executor::new(program, SP1CoreOpts::default());
        executor.run().unwrap();
        let record = executor.record.clone();

        let init_chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Initialize);
        let init_trace: RowMajorMatrix<BabyBear> =
            init_chip.generate_trace(&record, &mut ExecutionRecord::default());
        println!("{:?}", init_trace.values);

        let finalize_chip: MemoryGlobalChip = MemoryGlobalChip::new(MemoryChipType::Finalize);
        let finalize_trace: RowMajorMatrix<BabyBear> =
            finalize_chip.generate_trace(&record, &mut ExecutionRecord::default());
        println!("{:?}", finalize_trace.values);

        for mem_event in record.global_memory_finalize_events {
            println!("{mem_event:?}");
        }
    }

    /// Check memory lookup interactions per shard (local scope) and across shards (global).
    #[test]
    fn test_memory_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let program_clone = program.clone();
        let mut executor = Executor::new(program, SP1CoreOpts::default());
        executor.run().unwrap();
        let machine: StarkMachine<BabyBearPoseidon2, RiscvAir<BabyBear>> =
            RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program_clone);
        let opts = SP1CoreOpts::default();
        let mut unboxed: Vec<_> =
            executor.records.clone().into_iter().map(|record| *record).collect();
        machine.generate_dependencies(&mut unboxed, &opts, None);

        let shards = executor.records;
        for shard in &shards {
            debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
                &machine,
                &pkey,
                &[(**shard).clone()],
                vec![InteractionKind::Memory],
                InteractionScope::Local,
            );
        }
        let all_shards: Vec<_> = shards.into_iter().map(|record| *record).collect();
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &all_shards,
            vec![InteractionKind::Memory],
            InteractionScope::Global,
        );
    }

    /// Check byte lookup interactions across all shards in the global scope.
    #[test]
    fn test_byte_lookup_interactions() {
        setup_logger();
        let program = sha_extend_program();
        let program_clone = program.clone();
        let mut executor = Executor::new(program, SP1CoreOpts::default());
        executor.run().unwrap();
        let machine = RiscvAir::machine(BabyBearPoseidon2::new());
        let (pkey, _) = machine.setup(&program_clone);
        let opts = SP1CoreOpts::default();
        let mut unboxed: Vec<_> =
            executor.records.clone().into_iter().map(|record| *record).collect();
        machine.generate_dependencies(&mut unboxed, &opts, None);

        let all_shards: Vec<_> = executor.records.into_iter().map(|record| *record).collect();
        debug_interactions_with_all_chips::<BabyBearPoseidon2, RiscvAir<BabyBear>>(
            &machine,
            &pkey,
            &all_shards,
            vec![InteractionKind::Byte],
            InteractionScope::Global,
        );
    }
}