sp1_recursion_core/chips/mem/
constant.rs1use core::borrow::Borrow;
2use p3_air::{Air, BaseAir, PairBuilder};
3use p3_field::PrimeField32;
4use p3_matrix::{dense::RowMajorMatrix, Matrix};
5use sp1_derive::AlignedBorrow;
6use sp1_stark::air::MachineAir;
7use std::marker::PhantomData;
8
9#[cfg(feature = "sys")]
10use {
11 itertools::Itertools,
12 sp1_core_machine::utils::pad_rows_fixed,
13 std::{borrow::BorrowMut, iter::zip},
14};
15
16use crate::{builder::SP1RecursionAirBuilder, *};
17
18use super::MemoryAccessCols;
19
/// Number of `(value, access)` entries stored in each preprocessed trace row.
pub const NUM_CONST_MEM_ENTRIES_PER_ROW: usize = 2;
21
/// Chip type for constant memory accesses; it stores no data of its own.
///
/// The field type `F` is only recorded through a zero-sized [`PhantomData`] marker.
#[derive(Default)]
pub struct MemoryChip<F> {
    /// Ties the chip to its field type without storing a value of that type.
    _marker: PhantomData<F>,
}
26
27pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
28
29#[derive(AlignedBorrow, Debug, Clone, Copy)]
30#[repr(C)]
31pub struct MemoryCols<F: Copy> {
32 _nothing: F,
34}
35
36pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
37 core::mem::size_of::<MemoryPreprocessedCols<u8>>();
38
39#[derive(AlignedBorrow, Debug, Clone, Copy)]
40#[repr(C)]
41pub struct MemoryPreprocessedCols<F: Copy> {
42 values_and_accesses: [(Block<F>, MemoryAccessCols<F>); NUM_CONST_MEM_ENTRIES_PER_ROW],
43}
44impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
45 fn width(&self) -> usize {
46 NUM_MEM_INIT_COLS
47 }
48}
49
50impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
51 type Record = crate::ExecutionRecord<F>;
52
53 type Program = crate::RecursionProgram<F>;
54
55 fn name(&self) -> String {
56 "MemoryConst".to_string()
57 }
58 fn preprocessed_width(&self) -> usize {
59 NUM_MEM_PREPROCESSED_INIT_COLS
60 }
61
62 #[cfg(not(feature = "sys"))]
63 fn generate_preprocessed_trace(&self, _program: &Self::Program) -> Option<RowMajorMatrix<F>> {
64 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
65 }
66
67 #[cfg(feature = "sys")]
68 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
69 let mut rows = program
70 .inner
71 .iter()
72 .filter_map(|instruction| match instruction {
73 Instruction::Mem(MemInstr { addrs, vals, mult, kind }) => {
74 let mult = mult.to_owned();
75 let mult = match kind {
76 MemAccessKind::Read => -mult,
77 MemAccessKind::Write => mult,
78 };
79
80 Some((vals.inner, MemoryAccessCols { addr: addrs.inner, mult }))
81 }
82 _ => None,
83 })
84 .chunks(NUM_CONST_MEM_ENTRIES_PER_ROW)
85 .into_iter()
86 .map(|row_vs_as| {
87 let mut row = [F::zero(); NUM_MEM_PREPROCESSED_INIT_COLS];
88 let cols: &mut MemoryPreprocessedCols<_> = row.as_mut_slice().borrow_mut();
89 for (cell, access) in zip(&mut cols.values_and_accesses, row_vs_as) {
90 *cell = access;
91 }
92 row
93 })
94 .collect::<Vec<_>>();
95
96 pad_rows_fixed(
98 &mut rows,
99 || [F::zero(); NUM_MEM_PREPROCESSED_INIT_COLS],
100 program.fixed_log2_rows(self),
101 );
102
103 let trace = RowMajorMatrix::new(
105 rows.into_iter().flatten().collect::<Vec<_>>(),
106 NUM_MEM_PREPROCESSED_INIT_COLS,
107 );
108
109 Some(trace)
110 }
111
112 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
113 }
115
116 #[cfg(not(feature = "sys"))]
117 fn generate_trace(&self, _input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
118 unimplemented!("To generate traces, enable feature `sp1-recursion-core/sys`");
119 }
120
121 #[cfg(feature = "sys")]
122 fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
123 let num_rows = input
125 .mem_const_count
126 .checked_sub(1)
127 .map(|x| x / NUM_CONST_MEM_ENTRIES_PER_ROW + 1)
128 .unwrap_or_default();
129 let mut rows =
130 std::iter::repeat_n([F::zero(); NUM_MEM_INIT_COLS], num_rows).collect::<Vec<_>>();
131
132 pad_rows_fixed(&mut rows, || [F::zero(); NUM_MEM_INIT_COLS], input.fixed_log2_rows(self));
134
135 RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS)
137 }
138
139 fn included(&self, _record: &Self::Record) -> bool {
140 true
141 }
142
143 fn local_only(&self) -> bool {
144 true
145 }
146}
147
148impl<AB> Air<AB> for MemoryChip<AB::F>
149where
150 AB: SP1RecursionAirBuilder + PairBuilder,
151{
152 fn eval(&self, builder: &mut AB) {
153 let prep = builder.preprocessed();
154 let prep_local = prep.row_slice(0);
155 let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
156
157 for (value, access) in prep_local.values_and_accesses {
158 builder.send_block(access.addr, value, access.mult);
159 }
160 }
161}
162
#[cfg(all(test, feature = "sys"))]
mod tests {
    use machine::tests::test_recursion_linear_program;

    use super::*;

    use crate::runtime::instruction as instr;

    /// A write followed by a read with identical arguments is expected to prove.
    #[test]
    pub fn prove_basic_mem() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// The read differs from the write in its first numeric argument; expected to panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_mult() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 9, 1, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// The read differs from the write in its second numeric argument; expected to panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_address() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 9, 2),
        ];
        test_recursion_linear_program(program);
    }

    /// The read differs from the write in its third numeric argument; expected to panic.
    #[test]
    #[should_panic]
    pub fn basic_mem_bad_value() {
        let program = vec![
            instr::mem(MemAccessKind::Write, 1, 1, 2),
            instr::mem(MemAccessKind::Read, 1, 1, 999),
        ];
        test_recursion_linear_program(program);
    }
}