// sp1_recursion_core/chips/mem/constant.rs

1use core::borrow::Borrow;
2use p3_air::{Air, BaseAir, PairBuilder};
3use p3_field::PrimeField32;
4use p3_matrix::{dense::RowMajorMatrix, Matrix};
5use sp1_derive::AlignedBorrow;
6use sp1_stark::air::MachineAir;
7use std::marker::PhantomData;
8
9#[cfg(feature = "sys")]
10use {
11    itertools::Itertools,
12    sp1_core_machine::utils::pad_rows_fixed,
13    std::{borrow::BorrowMut, iter::zip},
14};
15
16use crate::{builder::SP1RecursionAirBuilder, *};
17
18use super::MemoryAccessCols;
19
/// How many `(value, access)` slots are packed side by side into one preprocessed-trace row.
pub const NUM_CONST_MEM_ENTRIES_PER_ROW: usize = 2;
21
/// Chip that passes the `(address, value)` pairs listed by the program's memory
/// instructions to the builder's `send_block`.
///
/// The chip stores no data of its own. The type parameter only fixes the field it is
/// defined over.
#[derive(Default)]
pub struct MemoryChip<F> {
    _marker: PhantomData<F>,
}
26
27pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
28
29#[derive(AlignedBorrow, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct MemoryCols<F: Copy> {
32    // At least one column is required, otherwise a bunch of things break.
33    _nothing: F,
34}
35
36pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
37    core::mem::size_of::<MemoryPreprocessedCols<u8>>();
38
39#[derive(AlignedBorrow, Debug, Clone, Copy)]
40#[repr(C)]
41pub struct MemoryPreprocessedCols<F: Copy> {
42    values_and_accesses: [(Block<F>, MemoryAccessCols<F>); NUM_CONST_MEM_ENTRIES_PER_ROW],
43}
44impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
45    fn width(&self) -> usize {
46        NUM_MEM_INIT_COLS
47    }
48}
49
50impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
51    type Record = crate::ExecutionRecord<F>;
52
53    type Program = crate::RecursionProgram<F>;
54
55    fn name(&self) -> String {
56        "MemoryConst".to_string()
57    }
58    fn preprocessed_width(&self) -> usize {
59        NUM_MEM_PREPROCESSED_INIT_COLS
60    }
61
62    #[cfg(not(feature = "sys"))]
63    fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
64        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
65    }
66
67    #[cfg(feature = "sys")]
68    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
69        let mut rows = program
70            .inner
71            .iter()
72            .filter_map(|instruction| match instruction {
73                Instruction::Mem(MemInstr { addrs, vals, mult, kind }) => {
74                    let mult = mult.to_owned();
75                    let mult = match kind {
76                        MemAccessKind::Read => -mult,
77                        MemAccessKind::Write => mult,
78                    };
79
80                    Some((vals.inner, MemoryAccessCols { addr: addrs.inner, mult }))
81                }
82                _ => None,
83            })
84            .chunks(NUM_CONST_MEM_ENTRIES_PER_ROW)
85            .into_iter()
86            .map(|row_vs_as| {
87                let mut row = [F::zero(); NUM_MEM_PREPROCESSED_INIT_COLS];
88                let cols: &mut MemoryPreprocessedCols<_> = row.as_mut_slice().borrow_mut();
89                for (cell, access) in zip(&mut cols.values_and_accesses, row_vs_as) {
90                    *cell = access;
91                }
92                row
93            })
94            .collect::<Vec<_>>();
95
96        // Pad the rows to the next power of two.
97        pad_rows_fixed(
98            &mut rows,
99            || [F::zero(); NUM_MEM_PREPROCESSED_INIT_COLS],
100            program.fixed_log2_rows(self),
101        );
102
103        // Convert the trace to a row major matrix.
104        let trace = RowMajorMatrix::new(
105            rows.into_iter().flatten().collect::<Vec<_>>(),
106            NUM_MEM_PREPROCESSED_INIT_COLS,
107        );
108
109        Some(trace)
110    }
111
112    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
113        // This is a no-op.
114    }
115
116    #[cfg(not(feature = "sys"))]
117    fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
118        unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
119    }
120
121    #[cfg(feature = "sys")]
122    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
123        // Match number of rows generated by the `.chunks` call in `generate_preprocessed_trace`.
124        let num_rows = input
125            .mem_const_count
126            .checked_sub(1)
127            .map(|x| x / NUM_CONST_MEM_ENTRIES_PER_ROW + 1)
128            .unwrap_or_default();
129        let mut rows =
130            std::iter::repeat_n([F::zero(); NUM_MEM_INIT_COLS], num_rows).collect::<Vec<_>>();
131
132        // Pad the rows to the next power of two.
133        pad_rows_fixed(&mut rows, || [F::zero(); NUM_MEM_INIT_COLS], input.fixed_log2_rows(self));
134
135        // Convert the trace to a row major matrix.
136        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS)
137    }
138
139    fn included(&self, _record: &Self::Record) -> bool {
140        true
141    }
142
143    fn local_only(&self) -> bool {
144        true
145    }
146}
147
148impl<AB> Air<AB> for MemoryChip<AB::F>
149where
150    AB: SP1RecursionAirBuilder + PairBuilder,
151{
152    fn eval(&self, builder: &mut AB) {
153        let prep = builder.preprocessed();
154        let prep_local = prep.row_slice(0);
155        let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
156
157        for (value, access) in prep_local.values_and_accesses {
158            builder.send_block(access.addr, value, access.mult);
159        }
160    }
161}
162
#[cfg(all(test, feature = "sys"))]
mod tests {
    use super::*;

    use crate::runtime::instruction as instr;
    use machine::tests::test_recursion_linear_program;

    /// A write followed by a read with the same multiplicity, address and value proves.
    #[test]
    pub fn prove_basic_mem() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// A read whose multiplicity differs from the write's must be rejected.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_mult() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 9, 1, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// A read from an address that was never written must be rejected.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_address() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 9, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// A read that claims a different value than the one written must be rejected.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_value() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 999),
        ];
        test_recursion_linear_program(program);
    }
}