sp1_core_machine/memory/program.rs

1use core::{
2    borrow::{Borrow, BorrowMut},
3    mem::size_of,
4};
5use itertools::Itertools;
6use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder};
7use p3_field::AbstractField;
8use p3_matrix::{dense::RowMajorMatrix, Matrix};
9
10use p3_field::PrimeField32;
11use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
12use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord, Program};
13use sp1_derive::AlignedBorrow;
14use sp1_stark::{
15    air::{
16        AirInteraction, InteractionScope, MachineAir, PublicValues, SP1AirBuilder,
17        SP1_PROOF_NUM_PV_ELTS,
18    },
19    InteractionKind, Word,
20};
21
22use crate::{
23    operations::IsZeroOperation,
24    utils::{next_power_of_two, pad_rows_fixed, zeroed_f_vec},
25};
26
27pub const NUM_MEMORY_PROGRAM_PREPROCESSED_COLS: usize =
28    size_of::<MemoryProgramPreprocessedCols<u8>>();
29pub const NUM_MEMORY_PROGRAM_MULT_COLS: usize = size_of::<MemoryProgramMultCols<u8>>();
30
/// The preprocessed (fixed) column layout for the chip: one row per entry of the program's
/// memory image, followed by zeroed padding rows.
///
/// The struct is `#[repr(C)]`, so the declaration order of the fields is the column order.
#[derive(AlignedBorrow, Clone, Copy, Default)]
#[repr(C)]
pub struct MemoryProgramPreprocessedCols<T> {
    /// The memory address initialized by this row.
    pub addr: T,
    /// The word stored at `addr` in the program's memory image.
    pub value: Word<T>,
    /// 1 for rows backed by a memory-image entry, 0 for padding rows.
    pub is_real: T,
}
39
/// Multiplicity (main-trace) columns.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct MemoryProgramMultCols<T: Copy> {
    /// The multiplicity of the event.
    ///
    /// Constrained in `eval` to be boolean, to equal `is_real` in the first shard, and to be
    /// zero in every other shard. It therefore coincides with `is_real` only in the first shard.
    pub multiplicity: T,

    /// Whether the shard is the first shard, witnessed as "`shard - 1` is zero".
    pub is_first_shard: IsZeroOperation<T>,
}
52
/// Chip that initializes memory that is provided from the program.
///
/// The table is preprocessed, with one row per entry of the program's memory image. Each real
/// row is sent to the global interaction table as a memory interaction. Per the original design
/// notes, the rows are meant to be consumed only in the first shard, which prevents any of these
/// addresses from being overwritten through the normal MemoryInit.
#[derive(Default)]
pub struct MemoryProgramChip;
58
59impl MemoryProgramChip {
60    pub const fn new() -> Self {
61        Self {}
62    }
63}
64
65impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
66    type Record = ExecutionRecord;
67
68    type Program = Program;
69
70    fn name(&self) -> String {
71        "MemoryProgram".to_string()
72    }
73
74    fn preprocessed_width(&self) -> usize {
75        NUM_MEMORY_PROGRAM_PREPROCESSED_COLS
76    }
77
78    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
79        // Generate the trace rows for each event.
80        let nb_rows = program.memory_image.len();
81        let size_log2 = program.fixed_log2_rows::<F, _>(self);
82        let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
83        let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS);
84        let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1);
85
86        let memory = program.memory_image.iter().sorted().collect::<Vec<_>>();
87        values
88            .chunks_mut(chunk_size * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS)
89            .enumerate()
90            .par_bridge()
91            .for_each(|(i, rows)| {
92                rows.chunks_mut(NUM_MEMORY_PROGRAM_PREPROCESSED_COLS).enumerate().for_each(
93                    |(j, row)| {
94                        let idx = i * chunk_size + j;
95
96                        if idx < nb_rows {
97                            let (addr, word) = memory[idx];
98                            let cols: &mut MemoryProgramPreprocessedCols<F> = row.borrow_mut();
99                            cols.addr = F::from_canonical_u32(*addr);
100                            cols.value = Word::from(*word);
101                            cols.is_real = F::one();
102                        }
103                    },
104                );
105            });
106
107        // Convert the trace to a row major matrix.
108        Some(RowMajorMatrix::new(values, NUM_MEMORY_PROGRAM_PREPROCESSED_COLS))
109    }
110
111    fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
112        let program_memory = &input.program.memory_image;
113
114        let mut events = Vec::new();
115        program_memory.iter().for_each(|(&addr, &word)| {
116            events.push(GlobalInteractionEvent {
117                message: [
118                    0,
119                    0,
120                    addr,
121                    word & 255,
122                    (word >> 8) & 255,
123                    (word >> 16) & 255,
124                    (word >> 24) & 255,
125                ],
126                is_receive: false,
127                kind: InteractionKind::Memory as u8,
128            });
129        });
130
131        output.global_interaction_events.extend(events);
132    }
133
134    fn generate_trace(
135        &self,
136        input: &ExecutionRecord,
137        _output: &mut ExecutionRecord,
138    ) -> RowMajorMatrix<F> {
139        let program_memory = &input.program.memory_image;
140
141        let mult_bool = input.public_values.shard == 1;
142        let mult = F::from_bool(mult_bool);
143
144        // Generate the trace rows for each event.
145        let mut rows = program_memory
146            .iter()
147            .map(|(&_, &_)| {
148                let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS];
149                let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut();
150                cols.multiplicity = mult;
151                cols.is_first_shard.populate(input.public_values.shard - 1);
152                row
153            })
154            .collect::<Vec<_>>();
155
156        // Pad the trace to a power of two depending on the proof shape in `input`.
157        pad_rows_fixed(
158            &mut rows,
159            || [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS],
160            input.fixed_log2_rows::<F, _>(self),
161        );
162
163        // Convert the trace to a row major matrix.
164        RowMajorMatrix::new(
165            rows.into_iter().flatten().collect::<Vec<_>>(),
166            NUM_MEMORY_PROGRAM_MULT_COLS,
167        )
168    }
169
170    fn included(&self, _: &Self::Record) -> bool {
171        false
172    }
173
174    fn commit_scope(&self) -> InteractionScope {
175        InteractionScope::Local
176    }
177}
178
179impl<F> BaseAir<F> for MemoryProgramChip {
180    fn width(&self) -> usize {
181        NUM_MEMORY_PROGRAM_MULT_COLS
182    }
183}
184
185impl<AB> Air<AB> for MemoryProgramChip
186where
187    AB: SP1AirBuilder + PairBuilder + AirBuilderWithPublicValues,
188{
189    fn eval(&self, builder: &mut AB) {
190        let preprocessed = builder.preprocessed();
191        let main = builder.main();
192
193        let prep_local = preprocessed.row_slice(0);
194        let prep_local: &MemoryProgramPreprocessedCols<AB::Var> = (*prep_local).borrow();
195
196        let mult_local = main.row_slice(0);
197        let mult_local: &MemoryProgramMultCols<AB::Var> = (*mult_local).borrow();
198
199        // Get shard from public values and evaluate whether it is the first shard.
200        let public_values_slice: [AB::Expr; SP1_PROOF_NUM_PV_ELTS] =
201            core::array::from_fn(|i| builder.public_values()[i].into());
202        let public_values: &PublicValues<Word<AB::Expr>, AB::Expr> =
203            public_values_slice.as_slice().borrow();
204
205        // Constrain `is_first_shard` to be 1 if and only if the shard is the first shard.
206        IsZeroOperation::<AB::F>::eval(
207            builder,
208            public_values.shard.clone() - AB::F::one(),
209            mult_local.is_first_shard,
210            prep_local.is_real.into(),
211        );
212
213        // Multiplicity must be either 0 or 1.
214        builder.assert_bool(mult_local.multiplicity);
215
216        // If first shard and preprocessed is real, multiplicity must be one.
217        builder
218            .when(mult_local.is_first_shard.result)
219            .assert_eq(mult_local.multiplicity, prep_local.is_real.into());
220
221        // If it's not the first shard, then the multiplicity must be zero.
222        builder.when_not(mult_local.is_first_shard.result).assert_zero(mult_local.multiplicity);
223
224        let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()];
225        values.extend(prep_local.value.map(Into::into));
226
227        // Send the interaction to the global table.
228        builder.send(
229            AirInteraction::new(
230                vec![
231                    AB::Expr::zero(),
232                    AB::Expr::zero(),
233                    prep_local.addr.into(),
234                    prep_local.value[0].into(),
235                    prep_local.value[1].into(),
236                    prep_local.value[2].into(),
237                    prep_local.value[3].into(),
238                    prep_local.is_real.into() * AB::Expr::zero(),
239                    prep_local.is_real.into() * AB::Expr::one(),
240                    AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
241                ],
242                prep_local.is_real.into(),
243                InteractionKind::Global,
244            ),
245            InteractionScope::Local,
246        );
247    }
248}