1use core::{
2 borrow::{Borrow, BorrowMut},
3 mem::size_of,
4};
5use itertools::Itertools;
6use p3_air::{Air, AirBuilder, AirBuilderWithPublicValues, BaseAir, PairBuilder};
7use p3_field::AbstractField;
8use p3_matrix::{dense::RowMajorMatrix, Matrix};
9
10use p3_field::PrimeField32;
11use p3_maybe_rayon::prelude::{ParallelBridge, ParallelIterator};
12use sp1_core_executor::{events::GlobalInteractionEvent, ExecutionRecord, Program};
13use sp1_derive::AlignedBorrow;
14use sp1_stark::{
15 air::{
16 AirInteraction, InteractionScope, MachineAir, PublicValues, SP1AirBuilder,
17 SP1_PROOF_NUM_PV_ELTS,
18 },
19 InteractionKind, Word,
20};
21
22use crate::{
23 operations::IsZeroOperation,
24 utils::{next_power_of_two, pad_rows_fixed, zeroed_f_vec},
25};
26
// Column counts for the two traces of the memory program chip. Each column struct is
// instantiated with `u8`, so its size in bytes equals its column count (assuming every field is
// composed solely of `T` values under `repr(C)`).

/// Number of preprocessed columns (width of [`MemoryProgramPreprocessedCols`]).
pub const NUM_MEMORY_PROGRAM_PREPROCESSED_COLS: usize =
    size_of::<MemoryProgramPreprocessedCols<u8>>();
/// Number of main-trace columns (width of [`MemoryProgramMultCols`]).
pub const NUM_MEMORY_PROGRAM_MULT_COLS: usize = size_of::<MemoryProgramMultCols<u8>>();
30
/// Preprocessed (program-fixed) columns of the memory program chip.
///
/// There is one row per entry of the program's initial memory image; any remaining rows up to the
/// padded height are all-zero.
#[derive(AlignedBorrow, Clone, Copy, Default)]
#[repr(C)]
pub struct MemoryProgramPreprocessedCols<T> {
    /// Memory address of the entry.
    pub addr: T,
    /// Word stored at `addr` in the initial memory image.
    pub value: Word<T>,
    /// 1 for rows backed by a memory-image entry, 0 for padding rows.
    pub is_real: T,
}
39
/// Main-trace columns of the memory program chip.
#[derive(AlignedBorrow, Clone, Copy)]
#[repr(C)]
pub struct MemoryProgramMultCols<T: Copy> {
    /// Trace generation sets this to 1 on memory-image rows when the shard number is 1, and to 0
    /// otherwise (padding rows are zero).
    pub multiplicity: T,

    /// Witnesses whether `shard - 1` is zero, i.e. whether the current shard is the first one.
    pub is_first_shard: IsZeroOperation<T>,
}
52
/// Chip that commits to the program's initial memory image as a preprocessed trace and sends each
/// `(address, word)` entry as a global memory interaction.
#[derive(Default)]
pub struct MemoryProgramChip;

impl MemoryProgramChip {
    /// Creates the chip.
    ///
    /// The chip carries no state, so this is equivalent to `MemoryProgramChip::default()`.
    pub const fn new() -> Self {
        MemoryProgramChip
    }
}
64
65impl<F: PrimeField32> MachineAir<F> for MemoryProgramChip {
66 type Record = ExecutionRecord;
67
68 type Program = Program;
69
70 fn name(&self) -> String {
71 "MemoryProgram".to_string()
72 }
73
74 fn preprocessed_width(&self) -> usize {
75 NUM_MEMORY_PROGRAM_PREPROCESSED_COLS
76 }
77
78 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
79 let nb_rows = program.memory_image.len();
81 let size_log2 = program.fixed_log2_rows::<F, _>(self);
82 let padded_nb_rows = next_power_of_two(nb_rows, size_log2);
83 let mut values = zeroed_f_vec(padded_nb_rows * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS);
84 let chunk_size = std::cmp::max((nb_rows + 1) / num_cpus::get(), 1);
85
86 let memory = program.memory_image.iter().sorted().collect::<Vec<_>>();
87 values
88 .chunks_mut(chunk_size * NUM_MEMORY_PROGRAM_PREPROCESSED_COLS)
89 .enumerate()
90 .par_bridge()
91 .for_each(|(i, rows)| {
92 rows.chunks_mut(NUM_MEMORY_PROGRAM_PREPROCESSED_COLS).enumerate().for_each(
93 |(j, row)| {
94 let idx = i * chunk_size + j;
95
96 if idx < nb_rows {
97 let (addr, word) = memory[idx];
98 let cols: &mut MemoryProgramPreprocessedCols<F> = row.borrow_mut();
99 cols.addr = F::from_canonical_u32(*addr);
100 cols.value = Word::from(*word);
101 cols.is_real = F::one();
102 }
103 },
104 );
105 });
106
107 Some(RowMajorMatrix::new(values, NUM_MEMORY_PROGRAM_PREPROCESSED_COLS))
109 }
110
111 fn generate_dependencies(&self, input: &ExecutionRecord, output: &mut ExecutionRecord) {
112 let program_memory = &input.program.memory_image;
113
114 let mut events = Vec::new();
115 program_memory.iter().for_each(|(&addr, &word)| {
116 events.push(GlobalInteractionEvent {
117 message: [
118 0,
119 0,
120 addr,
121 word & 255,
122 (word >> 8) & 255,
123 (word >> 16) & 255,
124 (word >> 24) & 255,
125 ],
126 is_receive: false,
127 kind: InteractionKind::Memory as u8,
128 });
129 });
130
131 output.global_interaction_events.extend(events);
132 }
133
134 fn generate_trace(
135 &self,
136 input: &ExecutionRecord,
137 _output: &mut ExecutionRecord,
138 ) -> RowMajorMatrix<F> {
139 let program_memory = &input.program.memory_image;
140
141 let mult_bool = input.public_values.shard == 1;
142 let mult = F::from_bool(mult_bool);
143
144 let mut rows = program_memory
146 .iter()
147 .map(|(&_, &_)| {
148 let mut row = [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS];
149 let cols: &mut MemoryProgramMultCols<F> = row.as_mut_slice().borrow_mut();
150 cols.multiplicity = mult;
151 cols.is_first_shard.populate(input.public_values.shard - 1);
152 row
153 })
154 .collect::<Vec<_>>();
155
156 pad_rows_fixed(
158 &mut rows,
159 || [F::zero(); NUM_MEMORY_PROGRAM_MULT_COLS],
160 input.fixed_log2_rows::<F, _>(self),
161 );
162
163 RowMajorMatrix::new(
165 rows.into_iter().flatten().collect::<Vec<_>>(),
166 NUM_MEMORY_PROGRAM_MULT_COLS,
167 )
168 }
169
170 fn included(&self, _: &Self::Record) -> bool {
171 false
172 }
173
174 fn commit_scope(&self) -> InteractionScope {
175 InteractionScope::Local
176 }
177}
178
/// The main trace is exactly the [`MemoryProgramMultCols`] layout.
impl<F> BaseAir<F> for MemoryProgramChip {
    /// Main-trace width: the number of [`MemoryProgramMultCols`] columns.
    fn width(&self) -> usize {
        NUM_MEMORY_PROGRAM_MULT_COLS
    }
}
184
185impl<AB> Air<AB> for MemoryProgramChip
186where
187 AB: SP1AirBuilder + PairBuilder + AirBuilderWithPublicValues,
188{
189 fn eval(&self, builder: &mut AB) {
190 let preprocessed = builder.preprocessed();
191 let main = builder.main();
192
193 let prep_local = preprocessed.row_slice(0);
194 let prep_local: &MemoryProgramPreprocessedCols<AB::Var> = (*prep_local).borrow();
195
196 let mult_local = main.row_slice(0);
197 let mult_local: &MemoryProgramMultCols<AB::Var> = (*mult_local).borrow();
198
199 let public_values_slice: [AB::Expr; SP1_PROOF_NUM_PV_ELTS] =
201 core::array::from_fn(|i| builder.public_values()[i].into());
202 let public_values: &PublicValues<Word<AB::Expr>, AB::Expr> =
203 public_values_slice.as_slice().borrow();
204
205 IsZeroOperation::<AB::F>::eval(
207 builder,
208 public_values.shard.clone() - AB::F::one(),
209 mult_local.is_first_shard,
210 prep_local.is_real.into(),
211 );
212
213 builder.assert_bool(mult_local.multiplicity);
215
216 builder
218 .when(mult_local.is_first_shard.result)
219 .assert_eq(mult_local.multiplicity, prep_local.is_real.into());
220
221 builder.when_not(mult_local.is_first_shard.result).assert_zero(mult_local.multiplicity);
223
224 let mut values = vec![AB::Expr::zero(), AB::Expr::zero(), prep_local.addr.into()];
225 values.extend(prep_local.value.map(Into::into));
226
227 builder.send(
229 AirInteraction::new(
230 vec![
231 AB::Expr::zero(),
232 AB::Expr::zero(),
233 prep_local.addr.into(),
234 prep_local.value[0].into(),
235 prep_local.value[1].into(),
236 prep_local.value[2].into(),
237 prep_local.value[3].into(),
238 prep_local.is_real.into() * AB::Expr::zero(),
239 prep_local.is_real.into() * AB::Expr::one(),
240 AB::Expr::from_canonical_u8(InteractionKind::Memory as u8),
241 ],
242 prep_local.is_real.into(),
243 InteractionKind::Global,
244 ),
245 InteractionScope::Local,
246 );
247 }
248}