sp1_recursion_machine/chips/mem/
variable.rs1use core::borrow::Borrow;
2use slop_air::{Air, BaseAir, PairBuilder};
3use slop_algebra::PrimeField32;
4use slop_matrix::Matrix;
5use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
6use sp1_derive::AlignedBorrow;
7use sp1_hypercube::{air::MachineAir, next_multiple_of_32};
8use sp1_recursion_executor::{
9 instruction::{HintAddCurveInstr, HintBitsInstr, HintExt2FeltsInstr, HintInstr},
10 Block, ExecutionRecord, Instruction, RecursionProgram,
11};
12use std::{borrow::BorrowMut, iter::zip, marker::PhantomData, mem::MaybeUninit};
13
14use crate::builder::SP1RecursionAirBuilder;
15
16use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS};
17
/// AIR chip for values written by hint-style instructions (`Hint`, `HintBits`,
/// `HintExt2Felts`, `HintAddCurve`).
///
/// Each trace row packs `VAR_EVENTS_PER_ROW` writes. For every write, the value block lives in
/// the main trace. The address and multiplicity live in the preprocessed trace. Each write is
/// emitted through `send_block`.
#[derive(Default, Clone)]
pub struct MemoryVarChip<F, const VAR_EVENTS_PER_ROW: usize> {
    /// Binds the chip to its field type without storing a field element.
    _marker: PhantomData<F>,
}
22
23pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryVarCols<u8, 1>>();
24
25#[derive(AlignedBorrow, Debug, Clone, Copy)]
26#[repr(C)]
27pub struct MemoryVarCols<F: Copy, const VAR_EVENTS_PER_ROW: usize> {
28 values: [Block<F>; VAR_EVENTS_PER_ROW],
29}
30
31pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
32 core::mem::size_of::<MemoryVarPreprocessedCols<u8, 1>>();
33
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct MemoryVarPreprocessedCols<F: Copy, const VAR_EVENTS_PER_ROW: usize> {
37 accesses: [MemoryAccessCols<F>; VAR_EVENTS_PER_ROW],
38}
39
40impl<F: Send + Sync, const VAR_EVENTS_PER_ROW: usize> BaseAir<F>
41 for MemoryVarChip<F, VAR_EVENTS_PER_ROW>
42{
43 fn width(&self) -> usize {
44 NUM_MEM_INIT_COLS * VAR_EVENTS_PER_ROW
45 }
46}
47
48impl<F: PrimeField32, const VAR_EVENTS_PER_ROW: usize> MachineAir<F>
49 for MemoryVarChip<F, VAR_EVENTS_PER_ROW>
50{
51 type Record = ExecutionRecord<F>;
52
53 type Program = RecursionProgram<F>;
54
55 fn name(&self) -> &'static str {
56 "MemoryVar"
57 }
58 fn preprocessed_width(&self) -> usize {
59 NUM_MEM_PREPROCESSED_INIT_COLS * VAR_EVENTS_PER_ROW
60 }
61
62 fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
63 let instrs_len = program
64 .inner
65 .iter()
66 .flat_map(|instruction| match instruction.inner() {
68 Instruction::Hint(HintInstr { output_addrs_mults })
69 | Instruction::HintBits(HintBitsInstr {
70 output_addrs_mults,
71 input_addr: _, }) => output_addrs_mults.iter().collect(),
73 Instruction::HintExt2Felts(HintExt2FeltsInstr {
74 output_addrs_mults,
75 input_addr: _, }) => output_addrs_mults.iter().collect(),
77 Instruction::HintAddCurve(instr) => {
78 let HintAddCurveInstr {
79 output_x_addrs_mults,
80 output_y_addrs_mults, .. } = instr.as_ref();
82 output_x_addrs_mults.iter().chain(output_y_addrs_mults.iter()).collect()
83 }
84 _ => vec![],
85 })
86 .count();
87 self.preprocessed_num_rows_with_instrs_len(program, instrs_len)
88 }
89
90 fn preprocessed_num_rows_with_instrs_len(
91 &self,
92 program: &Self::Program,
93 instrs_len: usize,
94 ) -> Option<usize> {
95 let height = program.shape.as_ref().and_then(|shape| shape.height(self));
96 Some(next_multiple_of_32(instrs_len.div_ceil(VAR_EVENTS_PER_ROW), height))
97 }
98
99 fn generate_preprocessed_trace_into(
100 &self,
101 program: &Self::Program,
102 buffer: &mut [MaybeUninit<F>],
103 ) {
104 let accesses = program
106 .inner
107 .iter()
108 .flat_map(|instruction| match instruction.inner() {
110 Instruction::Hint(HintInstr { output_addrs_mults })
111 | Instruction::HintBits(HintBitsInstr {
112 output_addrs_mults,
113 input_addr: _, }) => output_addrs_mults.iter().collect(),
115 Instruction::HintExt2Felts(HintExt2FeltsInstr {
116 output_addrs_mults,
117 input_addr: _, }) => output_addrs_mults.iter().collect(),
119 Instruction::HintAddCurve(instr) => {
120 let HintAddCurveInstr {
121 output_x_addrs_mults,
122 output_y_addrs_mults, .. } = instr.as_ref();
124 output_x_addrs_mults.iter().chain(output_y_addrs_mults.iter()).collect()
125 }
126 _ => vec![],
127 })
128 .collect::<Vec<_>>();
129
130 let padded_nb_rows =
131 self.preprocessed_num_rows_with_instrs_len(program, accesses.len()).unwrap();
132
133 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
134 let values = unsafe {
135 core::slice::from_raw_parts_mut(
136 buffer_ptr,
137 padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS * VAR_EVENTS_PER_ROW,
138 )
139 };
140
141 unsafe {
142 let padding_start = accesses.len() * NUM_MEM_ACCESS_COLS;
143 let padding_size = padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS * VAR_EVENTS_PER_ROW
144 - padding_start;
145 if padding_size > 0 {
146 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
147 }
148 }
149
150 let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS;
152 values[..populate_len]
153 .par_chunks_mut(NUM_MEM_ACCESS_COLS)
154 .zip_eq(accesses)
155 .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult });
156 }
157
158 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
159 }
161
162 fn num_rows(&self, input: &Self::Record) -> Option<usize> {
163 let height = input.program.shape.as_ref().and_then(|shape| shape.height(self));
164 let nb_rows = input.mem_var_events.len().div_ceil(VAR_EVENTS_PER_ROW);
165 let padded_nb_rows = next_multiple_of_32(nb_rows, height);
166 Some(padded_nb_rows)
167 }
168
169 fn generate_trace_into(
170 &self,
171 input: &ExecutionRecord<F>,
172 _: &mut ExecutionRecord<F>,
173 buffer: &mut [MaybeUninit<F>],
174 ) {
175 let padded_nb_rows = self.num_rows(input).unwrap();
176 let events = &input.mem_var_events;
177 let num_events = events.len();
178
179 unsafe {
180 let padding_start = num_events * NUM_MEM_INIT_COLS;
181 let padding_size =
182 padded_nb_rows * NUM_MEM_INIT_COLS * VAR_EVENTS_PER_ROW - padding_start;
183 if padding_size > 0 {
184 core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
185 }
186 }
187
188 let buffer_ptr = buffer.as_mut_ptr() as *mut F;
189 let values =
190 unsafe { core::slice::from_raw_parts_mut(buffer_ptr, num_events * NUM_MEM_INIT_COLS) };
191
192 let populate_len = events.len() * NUM_MEM_INIT_COLS;
194 values[..populate_len].par_chunks_mut(NUM_MEM_INIT_COLS).zip_eq(events).for_each(
195 |(row, &vals)| {
196 let cols: &mut Block<F> = row.borrow_mut();
197 *cols = vals.inner;
198 },
199 );
200 }
201
202 fn included(&self, _record: &Self::Record) -> bool {
203 true
204 }
205}
206
207impl<AB, const VAR_EVENTS_PER_ROW: usize> Air<AB> for MemoryVarChip<AB::F, VAR_EVENTS_PER_ROW>
208where
209 AB: SP1RecursionAirBuilder + PairBuilder,
210{
211 fn eval(&self, builder: &mut AB) {
212 let main = builder.main();
213 let local = main.row_slice(0);
214 let local: &MemoryVarCols<AB::Var, VAR_EVENTS_PER_ROW> = (*local).borrow();
215 let prep = builder.preprocessed();
216 let prep_local = prep.row_slice(0);
217 let prep_local: &MemoryVarPreprocessedCols<AB::Var, VAR_EVENTS_PER_ROW> =
218 (*prep_local).borrow();
219
220 for (value, access) in zip(local.values, prep_local.accesses) {
221 builder.send_block(access.addr, value, access.mult);
222 }
223 }
224}
225
#[cfg(test)]
mod tests {
    use slop_algebra::AbstractField;

    use slop_matrix::dense::RowMajorMatrix;
    use sp1_primitives::SP1Field;
    use sp1_recursion_executor::MemEvent;

    use crate::chips::test_fixtures;

    use super::*;

    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let chip = MemoryVarChip::<_, 2>::default();
        let trace = chip.generate_trace(shard, &mut ExecutionRecord::default());
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let program = &test_fixtures::program_with_input().await.0;
        let chip = MemoryVarChip::<_, 2>::default();
        let trace = chip.generate_preprocessed_trace(program).unwrap();
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// With `VAR_EVENTS_PER_ROW = 2`, two distinct events must land side by side, in order, at
    /// the start of the trace. Every remaining (padding) cell must be zero.
    #[test]
    pub fn generate_trace_simple() {
        let shard = ExecutionRecord::<SP1Field> {
            mem_var_events: vec![
                MemEvent { inner: SP1Field::one().into() },
                MemEvent { inner: SP1Field::two().into() },
            ],
            ..Default::default()
        };
        let chip = MemoryVarChip::<_, 2>::default();
        let trace: RowMajorMatrix<SP1Field> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());

        assert_eq!(trace.width(), 2 * NUM_MEM_INIT_COLS);

        // Expected layout: event blocks back to back from the start of the trace, then zeros.
        let mut expected = vec![SP1Field::zero(); trace.values.len()];
        for (slot, event) in
            expected.chunks_exact_mut(NUM_MEM_INIT_COLS).zip(&shard.mem_var_events)
        {
            let block: &mut Block<SP1Field> = slot.borrow_mut();
            *block = event.inner;
        }
        assert_eq!(trace.values, expected);
    }
}