sp1_recursion_core_v2/chips/mem/variable.rs

1use core::borrow::Borrow;
2use instruction::{HintBitsInstr, HintExt2FeltsInstr, HintInstr};
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::{next_power_of_two, pad_to_power_of_two};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::MachineAir;
10use std::{borrow::BorrowMut, iter::zip, marker::PhantomData};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
14use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS};
15
/// Chip whose rows carry the variable-memory values (reported as `MemoryVar` by
/// `MachineAir::name`). It holds no state; the type parameter only fixes the field.
#[derive(Default)]
pub struct MemoryChip<F> {
    _data: PhantomData<F>,
}

/// How many memory slots are laid out side by side in a single trace row.
pub const NUM_MEM_ENTRIES_PER_ROW: usize = 16;
22
23pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
24
25#[derive(AlignedBorrow, Debug, Clone, Copy)]
26#[repr(C)]
27pub struct MemoryCols<F: Copy> {
28    values: [Block<F>; NUM_MEM_ENTRIES_PER_ROW],
29}
30
31pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
32    core::mem::size_of::<MemoryPreprocessedCols<u8>>();
33
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct MemoryPreprocessedCols<F: Copy> {
37    accesses: [MemoryAccessCols<F>; NUM_MEM_ENTRIES_PER_ROW],
38}
39
40impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
41    fn width(&self) -> usize {
42        NUM_MEM_INIT_COLS
43    }
44}
45
46impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
47    type Record = crate::ExecutionRecord<F>;
48
49    type Program = crate::RecursionProgram<F>;
50
51    fn name(&self) -> String {
52        "MemoryVar".to_string()
53    }
54    fn preprocessed_width(&self) -> usize {
55        NUM_MEM_PREPROCESSED_INIT_COLS
56    }
57
58    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
59        // Allocating an intermediate `Vec` is faster.
60        let accesses = program
61            .instructions
62            .par_iter() // Using `rayon` here provides a big speedup.
63            .flat_map_iter(|instruction| match instruction {
64                Instruction::Hint(HintInstr { output_addrs_mults })
65                | Instruction::HintBits(HintBitsInstr {
66                    output_addrs_mults,
67                    input_addr: _, // No receive interaction for the hint operation
68                }) => output_addrs_mults.iter().collect(),
69                Instruction::HintExt2Felts(HintExt2FeltsInstr {
70                    output_addrs_mults,
71                    input_addr: _, // No receive interaction for the hint operation
72                }) => output_addrs_mults.iter().collect(),
73                _ => vec![],
74            })
75            .collect::<Vec<_>>();
76
77        let nb_rows = accesses.len().div_ceil(NUM_MEM_ENTRIES_PER_ROW);
78        let padded_nb_rows = next_power_of_two(nb_rows, None);
79        let mut values = vec![F::zero(); padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS];
80        // Generate the trace rows & corresponding records for each chunk of events in parallel.
81        let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS;
82        values[..populate_len]
83            .par_chunks_mut(NUM_MEM_ACCESS_COLS)
84            .zip_eq(accesses)
85            .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult });
86
87        Some(RowMajorMatrix::new(values, NUM_MEM_PREPROCESSED_INIT_COLS))
88    }
89
90    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
91        // This is a no-op.
92    }
93
94    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
95        // Generate the trace rows & corresponding records for each chunk of events in parallel.
96        let rows = input
97            .mem_var_events
98            .chunks(NUM_MEM_ENTRIES_PER_ROW)
99            .map(|row_events| {
100                let mut row = [F::zero(); NUM_MEM_INIT_COLS];
101                let cols: &mut MemoryCols<_> = row.as_mut_slice().borrow_mut();
102                for (cell, vals) in zip(&mut cols.values, row_events) {
103                    *cell = vals.inner;
104                }
105                row
106            })
107            .collect::<Vec<_>>();
108
109        // Convert the trace to a row major matrix.
110        let mut trace =
111            RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS);
112
113        // Pad the trace to a power of two.
114        pad_to_power_of_two::<NUM_MEM_INIT_COLS, F>(&mut trace.values);
115
116        trace
117    }
118
119    fn included(&self, _record: &Self::Record) -> bool {
120        true
121    }
122}
123
124impl<AB> Air<AB> for MemoryChip<AB::F>
125where
126    AB: SP1RecursionAirBuilder + PairBuilder,
127{
128    fn eval(&self, builder: &mut AB) {
129        let main = builder.main();
130        let local = main.row_slice(0);
131        let local: &MemoryCols<AB::Var> = (*local).borrow();
132        let prep = builder.preprocessed();
133        let prep_local = prep.row_slice(0);
134        let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
135
136        for (value, access) in zip(local.values, prep_local.accesses) {
137            builder.send_block(access.addr, value, access.mult);
138        }
139    }
140}
141
#[cfg(test)]
mod tests {
    use machine::tests::run_recursion_test_machines;
    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    use crate::runtime::instruction as instr;

    /// Checks the main-trace shape.
    ///
    /// Two events must occupy the first two slots of row 0 (non-zero blocks). Every other
    /// cell, including all padding rows, must be zero.
    #[test]
    pub fn generate_trace() {
        let shard = ExecutionRecord::<BabyBear> {
            mem_var_events: vec![
                MemEvent { inner: BabyBear::one().into() },
                MemEvent { inner: BabyBear::one().into() },
            ],
            ..Default::default()
        };
        let chip = MemoryChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());

        assert_eq!(trace.width(), NUM_MEM_INIT_COLS);
        assert!(trace.height().is_power_of_two());

        // Number of field elements occupied by one memory slot (one `Block`).
        let slot_width = NUM_MEM_INIT_COLS / NUM_MEM_ENTRIES_PER_ROW;
        let is_zero = |v: &BabyBear| *v == BabyBear::zero();
        assert!(!trace.values[..slot_width].iter().all(is_zero));
        assert!(!trace.values[slot_width..2 * slot_width].iter().all(is_zero));
        assert!(trace.values[2 * slot_width..].iter().all(is_zero));
    }

    /// A write followed by a matching read must prove successfully.
    #[test]
    pub fn prove_basic_mem() {
        let program = RecursionProgram {
            instructions: vec![
                instr::mem(MemAccessKind::Write, 1, 1, 2),
                instr::mem(MemAccessKind::Read, 1, 1, 2),
            ],
            ..Default::default()
        };

        run_recursion_test_machines(program);
    }

    /// A read whose multiplicity does not match the write must fail.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_mult() {
        let program = RecursionProgram {
            instructions: vec![
                instr::mem(MemAccessKind::Write, 1, 1, 2),
                instr::mem(MemAccessKind::Read, 999, 1, 2),
            ],
            ..Default::default()
        };

        run_recursion_test_machines(program);
    }

    /// A read from an address that was never written must fail.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_address() {
        let program = RecursionProgram {
            instructions: vec![
                instr::mem(MemAccessKind::Write, 1, 1, 2),
                instr::mem(MemAccessKind::Read, 1, 999, 2),
            ],
            ..Default::default()
        };

        run_recursion_test_machines(program);
    }

    /// A read expecting a different value than the one written must fail.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_value() {
        let program = RecursionProgram {
            instructions: vec![
                instr::mem(MemAccessKind::Write, 1, 1, 2),
                instr::mem(MemAccessKind::Read, 1, 1, 999),
            ],
            ..Default::default()
        };

        run_recursion_test_machines(program);
    }
}