sp1_recursion_core/chips/mem/
variable.rs

1use core::borrow::Borrow;
2use instruction::{HintAddCurveInstr, HintBitsInstr, HintExt2FeltsInstr, HintInstr};
3use p3_air::{Air, BaseAir, PairBuilder};
4use p3_field::PrimeField32;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::{next_power_of_two, pad_rows_fixed};
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::MachineAir;
10use std::{borrow::BorrowMut, iter::zip, marker::PhantomData};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
14use super::{MemoryAccessCols, NUM_MEM_ACCESS_COLS};
15
/// How many variable-memory entries are packed into a single trace row.
pub const NUM_VAR_MEM_ENTRIES_PER_ROW: usize = 2;

/// Chip for the variable-memory table (reported under the name `"MemoryVar"`).
///
/// The chip carries no data of its own; `F` is only recorded at the type level so the
/// trait implementations can be tied to a concrete field.
#[derive(Default)]
pub struct MemoryChip<F> {
    /// Zero-sized marker binding the chip to the field type `F`.
    _marker: PhantomData<F>,
}
22
23pub const NUM_MEM_INIT_COLS: usize = core::mem::size_of::<MemoryCols<u8>>();
24
25#[derive(AlignedBorrow, Debug, Clone, Copy)]
26#[repr(C)]
27pub struct MemoryCols<F: Copy> {
28    values: [Block<F>; NUM_VAR_MEM_ENTRIES_PER_ROW],
29}
30
31pub const NUM_MEM_PREPROCESSED_INIT_COLS: usize =
32    core::mem::size_of::<MemoryPreprocessedCols<u8>>();
33
34#[derive(AlignedBorrow, Debug, Clone, Copy)]
35#[repr(C)]
36pub struct MemoryPreprocessedCols<F: Copy> {
37    accesses: [MemoryAccessCols<F>; NUM_VAR_MEM_ENTRIES_PER_ROW],
38}
39
40impl<F: Send + Sync> BaseAir<F> for MemoryChip<F> {
41    fn width(&self) -> usize {
42        NUM_MEM_INIT_COLS
43    }
44}
45
46impl<F: PrimeField32> MachineAir<F> for MemoryChip<F> {
47    type Record = crate::ExecutionRecord<F>;
48
49    type Program = crate::RecursionProgram<F>;
50
51    fn name(&self) -> String {
52        "MemoryVar".to_string()
53    }
54    fn preprocessed_width(&self) -> usize {
55        NUM_MEM_PREPROCESSED_INIT_COLS
56    }
57
58    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
59        // Allocating an intermediate `Vec` is faster.
60        let accesses = program
61            .inner
62            .iter()
63            // .par_bridge() // Using `rayon` here provides a big speedup. TODO put rayon back
64            .flat_map(|instruction| match instruction {
65                Instruction::Hint(HintInstr { output_addrs_mults }) |
66                Instruction::HintBits(HintBitsInstr {
67                    output_addrs_mults,
68                    input_addr: _, // No receive interaction for the hint operation
69                }) => output_addrs_mults.iter().collect(),
70                Instruction::HintExt2Felts(HintExt2FeltsInstr {
71                    output_addrs_mults,
72                    input_addr: _, // No receive interaction for the hint operation
73                }) => output_addrs_mults.iter().collect(),
74                Instruction::HintAddCurve(instr) => {
75                    let HintAddCurveInstr {
76                        output_x_addrs_mults,
77                        output_y_addrs_mults, .. // No receive interaction for the hint operation
78                    } = instr.as_ref();
79                    output_x_addrs_mults.iter().chain(output_y_addrs_mults.iter()).collect()
80                }
81                _ => vec![],
82            })
83            .collect::<Vec<_>>();
84
85        let nb_rows = accesses.len().div_ceil(NUM_VAR_MEM_ENTRIES_PER_ROW);
86        let padded_nb_rows = match program.fixed_log2_rows(self) {
87            Some(log2_rows) => 1 << log2_rows,
88            None => next_power_of_two(nb_rows, None),
89        };
90        let mut values = vec![F::zero(); padded_nb_rows * NUM_MEM_PREPROCESSED_INIT_COLS];
91
92        // Generate the trace rows & corresponding records for each chunk of events in parallel.
93        let populate_len = accesses.len() * NUM_MEM_ACCESS_COLS;
94        values[..populate_len]
95            .par_chunks_mut(NUM_MEM_ACCESS_COLS)
96            .zip_eq(accesses)
97            .for_each(|(row, &(addr, mult))| *row.borrow_mut() = MemoryAccessCols { addr, mult });
98
99        Some(RowMajorMatrix::new(values, NUM_MEM_PREPROCESSED_INIT_COLS))
100    }
101
102    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
103        // This is a no-op.
104    }
105
106    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
107        // Generate the trace rows & corresponding records for each chunk of events in parallel.
108        let mut rows = input
109            .mem_var_events
110            .chunks(NUM_VAR_MEM_ENTRIES_PER_ROW)
111            .map(|row_events| {
112                let mut row = [F::zero(); NUM_MEM_INIT_COLS];
113                let cols: &mut MemoryCols<_> = row.as_mut_slice().borrow_mut();
114                for (cell, vals) in zip(&mut cols.values, row_events) {
115                    *cell = vals.inner;
116                }
117                row
118            })
119            .collect::<Vec<_>>();
120
121        // Pad the rows to the next power of two.
122        pad_rows_fixed(&mut rows, || [F::zero(); NUM_MEM_INIT_COLS], input.fixed_log2_rows(self));
123
124        // Convert the trace to a row major matrix.
125        RowMajorMatrix::new(rows.into_iter().flatten().collect::<Vec<_>>(), NUM_MEM_INIT_COLS)
126    }
127
128    fn included(&self, _record: &Self::Record) -> bool {
129        true
130    }
131
132    fn local_only(&self) -> bool {
133        true
134    }
135}
136
137impl<AB> Air<AB> for MemoryChip<AB::F>
138where
139    AB: SP1RecursionAirBuilder + PairBuilder,
140{
141    fn eval(&self, builder: &mut AB) {
142        let main = builder.main();
143        let local = main.row_slice(0);
144        let local: &MemoryCols<AB::Var> = (*local).borrow();
145        let prep = builder.preprocessed();
146        let prep_local = prep.row_slice(0);
147        let prep_local: &MemoryPreprocessedCols<AB::Var> = (*prep_local).borrow();
148
149        for (value, access) in zip(local.values, prep_local.accesses) {
150            builder.send_block(access.addr, value, access.mult);
151        }
152    }
153}
154
#[cfg(test)]
mod tests {
    // The test prints the generated trace to ease manual inspection.
    #![allow(clippy::print_stdout)]

    use p3_baby_bear::BabyBear;
    use p3_field::AbstractField;
    use p3_matrix::dense::RowMajorMatrix;

    use super::*;

    /// Generates the main trace for a record holding two memory events and checks that
    /// the resulting matrix has the chip's column count and consists of whole rows.
    #[test]
    pub fn generate_trace() {
        let shard = ExecutionRecord::<BabyBear> {
            mem_var_events: vec![
                MemEvent { inner: BabyBear::one().into() },
                MemEvent { inner: BabyBear::one().into() },
            ],
            ..Default::default()
        };
        let chip = MemoryChip::default();
        let trace: RowMajorMatrix<BabyBear> =
            chip.generate_trace(&shard, &mut ExecutionRecord::default());
        println!("{:?}", trace.values);

        // Previously the test could only fail by panicking; pin the trace shape as well.
        assert_eq!(trace.width, NUM_MEM_INIT_COLS);
        assert!(!trace.values.is_empty());
        assert_eq!(trace.values.len() % NUM_MEM_INIT_COLS, 0);
    }
}