sp1_recursion_core_v2/chips/mem/
variable.rs1use core::borrow::Borrow;
2use instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr};
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::{next_power_of_two, pad_to_power_of_two};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::MachineAir;
10use std::{borrow::BorrowMut, iter::zip, marker::PhantomData};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
14use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS};
15
/// Number of memory entries (value blocks and their matching accesses) stored in one trace row.
pub const NUM_MEM_ENTRIES_PER_ROW: usize = 16;

/// Chip reported under the name `"MemoryVar"`.
///
/// Each main-trace row holds up to [`NUM_MEM_ENTRIES_PER_ROW`] blocks taken from the record's
/// `mem_var_events`, and each preprocessed row holds the matching address/multiplicity pairs.
#[derive(Default)]
pub struct MemoryChip<F> {
    /// Zero-sized marker that binds the chip to the field type `F` without storing a value.
    _data: PhantomData<F>,
}
22
23pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
24
25#[derive(AlignedBorrow, Debug, Clone, Copy)]
26#[repr(C)]
27pub struct MemoryCols<F: Copy> {
28 values: [Block<F>; NUM_MEM_ENTRIES_PER_ROW],
29}
30
31pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
32 core::mem::size_of::<MemoryPreprocessedCols<u8>>();
33
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct MemoryPreprocessedCols<F: Copy> {
37 accesses: [MemoryAccessCols<F>; NUM_MEM_ENTRIES_PER_ROW],
38}
39
40impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
41 fn width(&self) -> usize {
42 NUM_MEM_INIT_COLS
43 }
44}
45
46impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
47 type Record = crate::ExecutionRecord<F>;
48
49 type Program = crate::RecursionProgram<F>;
50
51 fn name(&self) -> String {
52 "MemoryVar".to_string()
53 }
54 fn preprocessed_width(&self) -> usize {
55 NUM_MEM_PREPROCESSED_INIT_COLS
56 }
57
58 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
59 let accesses = program
61 .instructions
62 .par_iter() .flat_map_iter(|instruction| match instruction {
64 Instruction::Hint(HintInstr { output_addrs_mults })
65 | Instruction::HintBits(HintBitsInstr {
66 output_addrs_mults,
67 input_addr: _, }) => output_addrs_mults.iter().collect(),
69 Instruction::HintExt2Felts(HintExt2FeltsInstr {
70 output_addrs_mults,
71 input_addr: _, }) => output_addrs_mults.iter().collect(),
73 _ => vec![],
74 })
75 .collect::<Vec<_>>();
76
77 let nb_rows = accesses.len().div_ceil(NUM_MEM_ENTRIES_PER_ROW);
78 let padded_nb_rows = next_power_of_two(nb_rows, None);
79 let mut values = vec![F::zero(); padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS];
80 let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS;
82 values[..populate_len]
83 .par_chunks_mut(NUM_MEM_ACCESS_COLS)
84 .zip_eq(accesses)
85 .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult });
86
87 Some(RowMajorMatrix::new(values, NUM_MEM_PREPROCESSED_INIT_COLS))
88 }
89
90 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
91 }
93
94 fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
95 let rows = input
97 .mem_var_events
98 .chunks(NUM_MEM_ENTRIES_PER_ROW)
99 .map(|row_events| {
100 let mut row = [F::zero(); NUM_MEM_INIT_COLS];
101 let cols: &mut MemoryCols<_> = row.as_mut_slice().borrow_mut();
102 for (cell, vals) in zip(&mut cols.values, row_events) {
103 *cell = vals.inner;
104 }
105 row
106 })
107 .collect::<Vec<_>>();
108
109 let mut trace =
111 RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS);
112
113 pad_to_power_of_two::<NUM_MEM_INIT_COLS, F>(&mut trace.values);
115
116 trace
117 }
118
119 fn included(&self, _record: &Self::Record) -> bool {
120 true
121 }
122}
123
124impl<AB> Air<AB> for MemoryChip<AB::F>
125where
126 AB: SP1RecursionAirBuilder + PairBuilder,
127{
128 fn eval(&self, builder: &mut AB) {
129 let main = builder.main();
130 let local = main.row_slice(0);
131 let local: &MemoryCols<AB::Var> = (*local).borrow();
132 let prep = builder.preprocessed();
133 let prep_local = prep.row_slice(0);
134 let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
135
136 for (value, access) in zip(local.values, prep_local.accesses) {
137 builder.send_block(access.addr, value, access.mult);
138 }
139 }
140}
141
#[cfg(test)]
mod tests {
    use machine::tests::run_recursion_test_machines;
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    use crate::runtime::instruction as instr;

    /// Smoke test: builds a trace from two events and prints its values.
    #[test]
    pub fn generate_trace() {
        let record = ExecutionRecord::<BabyBear> {
            mem_var_events: vec![
                MemEvent { inner: BabyBear::one().into() },
                MemEvent { inner: BabyBear::one().into() },
            ],
            ..Default::default()
        };
        let chip = MemoryChip::default();
        let mut output = ExecutionRecord::default();
        let trace: RowMajorMatrix<BabyBear> = chip.generate_trace(&record, &mut output);
        println!("{:?}", trace.values);
    }

    /// A write followed by a read with identical arguments must go through the test machines.
    #[test]
    pub fn prove_basic_mem() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 2),
        ];
        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }

    /// The read's second argument (999) differs from the write's (1); the run must panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_mult() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 999, 1, 2),
        ];
        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }

    /// The read's third argument (999) differs from the write's (1); the run must panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_address() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 999, 2),
        ];
        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }

    /// The read's fourth argument (999) differs from the write's (2); the run must panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_value() {
        let instructions = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 999),
        ];
        let program = RecursionProgram { instructions, ..Default::default() };

        run_recursion_test_machines(program);
    }
}