sp1_recursion_core/chips/mem/
variable.rs1use core::borrow::Borrow;
2use instruction::{HintAddCurveInstr, HintBitsInstr, HintExt2FeltsInstr, HintInstr};
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::{next_power_of_two, pad_rows_fixed};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::MachineAir;
10use std::{borrow::BorrowMut, iter::zip, marker::PhantomData};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
14use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS};
15
/// Number of memory entries stored in a single trace row.
///
/// Used as the array length of both `MemoryCols::values` and
/// `MemoryPreprocessedCols::accesses`, and as the chunk size when events are grouped into rows.
pub const NUM_VAR_MEM_ENTRIES_PER_ROW: usize = 2;
17
/// Chip type for the trace that its `MachineAir` implementation names `"MemoryVar"`.
///
/// The chip carries no data of its own; the `PhantomData` marker only ties it to the field
/// type `F`.
#[derive(Default)]
pub struct MemoryChip<F> {
    _marker: PhantomData<F>,
}
22
23pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
24
25#[derive(AlignedBorrow, Debug, Clone, Copy)]
26#[repr(C)]
27pub struct MemoryCols<F: Copy> {
28 values: [Block<F>; NUM_VAR_MEM_ENTRIES_PER_ROW],
29}
30
31pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
32 core::mem::size_of::<MemoryPreprocessedCols<u8>>();
33
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct MemoryPreprocessedCols<F: Copy> {
37 accesses: [MemoryAccessCols<F>; NUM_VAR_MEM_ENTRIES_PER_ROW],
38}
39
40impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
41 fn width(&self) -> usize {
42 NUM_MEM_INIT_COLS
43 }
44}
45
46impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
47 type Record = crate::ExecutionRecord<F>;
48
49 type Program = crate::RecursionProgram<F>;
50
51 fn name(&self) -> String {
52 "MemoryVar".to_string()
53 }
54 fn preprocessed_width(&self) -> usize {
55 NUM_MEM_PREPROCESSED_INIT_COLS
56 }
57
58 fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
59 let accesses = program
61 .inner
62 .iter()
63 .flat_map(|instruction| match instruction {
65 Instruction::Hint(HintInstr { output_addrs_mults }) |
66 Instruction::HintBits(HintBitsInstr {
67 output_addrs_mults,
68 input_addr: _, }) => output_addrs_mults.iter().collect(),
70 Instruction::HintExt2Felts(HintExt2FeltsInstr {
71 output_addrs_mults,
72 input_addr: _, }) => output_addrs_mults.iter().collect(),
74 Instruction::HintAddCurve(instr) => {
75 let HintAddCurveInstr {
76 output_x_addrs_mults,
77 output_y_addrs_mults, .. } = instr.as_ref();
79 output_x_addrs_mults.iter().chain(output_y_addrs_mults.iter()).collect()
80 }
81 _ => vec![],
82 })
83 .collect::<Vec<_>>();
84
85 let nb_rows = accesses.len().div_ceil(NUM_VAR_MEM_ENTRIES_PER_ROW);
86 let padded_nb_rows = match program.fixed_log2_rows(self) {
87 Some(log2_rows) => 1 << log2_rows,
88 None => next_power_of_two(nb_rows, None),
89 };
90 let mut values = vec![F::zero(); padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS];
91
92 let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS;
94 values[..populate_len]
95 .par_chunks_mut(NUM_MEM_ACCESS_COLS)
96 .zip_eq(accesses)
97 .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult });
98
99 Some(RowMajorMatrix::new(values, NUM_MEM_PREPROCESSED_INIT_COLS))
100 }
101
102 fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
103 }
105
106 fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
107 let mut rows = input
109 .mem_var_events
110 .chunks(NUM_VAR_MEM_ENTRIES_PER_ROW)
111 .map(|row_events| {
112 let mut row = [F::zero(); NUM_MEM_INIT_COLS];
113 let cols: &mut MemoryCols<_> = row.as_mut_slice().borrow_mut();
114 for (cell, vals) in zip(&mut cols.values, row_events) {
115 *cell = vals.inner;
116 }
117 row
118 })
119 .collect::<Vec<_>>();
120
121 pad_rows_fixed(&mut rows, || [F::zero(); NUM_MEM_INIT_COLS], input.fixed_log2_rows(self));
123
124 RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS)
126 }
127
128 fn included(&self, _record: &Self::Record) -> bool {
129 true
130 }
131
132 fn local_only(&self) -> bool {
133 true
134 }
135}
136
137impl<AB> Air<AB> for MemoryChip<AB::F>
138where
139 AB: SP1RecursionAirBuilder + PairBuilder,
140{
141 fn eval(&self, builder: &mut AB) {
142 let main = builder.main();
143 let local = main.row_slice(0);
144 let local: &MemoryCols<AB::Var> = (*local).borrow();
145 let prep = builder.preprocessed();
146 let prep_local = prep.row_slice(0);
147 let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
148
149 for (value, access) in zip(local.values, prep_local.accesses) {
150 builder.send_block(access.addr, value, access.mult);
151 }
152 }
153}
154
#[cfg(test)]
mod tests {
    // The test prints the generated trace so it can be inspected with `--nocapture`.
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    /// Generates the main trace for two memory events and checks its basic shape.
    #[test]
    pub fn generate_trace() {
        let shard = ExecutionRecord::<BabyBear> {
            mem_var_events: vec![
                MemEvent { inner: BabyBear::one().into() },
                MemEvent { inner: BabyBear::one().into() },
            ],
            ..Default::default()
        };
        let chip = MemoryChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        // Two events fit in a single row, so at least one full row must be present and the
        // buffer must consist of whole rows.
        assert!(!trace.values.is_empty());
        assert_eq!(trace.values.len() % NUM_MEM_INIT_COLS, 0);
        // The event payload is copied into the first row, so the value it was built from
        // must appear there.
        assert!(trace.values[..NUM_MEM_INIT_COLS].contains(&BabyBear::one()));
    }
}