//! BaseAlu chip (`sp1_recursion_machine/chips/alu_base.rs`): proves base-field
//! add/sub/mul/div operations of the recursion VM.
use core::borrow::Borrow;
use std::{borrow::BorrowMut, iter::zip, mem::MaybeUninit};

use slop_air::{Air, AirBuilder, BaseAir, PairBuilder};
use slop_algebra::{Field, PrimeField32};
use slop_matrix::Matrix;
use slop_maybe_rayon::prelude::{IndexedParallelIterator, ParallelIterator, ParallelSliceMut};
use sp1_derive::AlignedBorrow;
use sp1_hypercube::{air::MachineAir, next_multiple_of_32};
use sp1_primitives::SP1Field;
use sp1_recursion_executor::{
    Address, BaseAluInstr, BaseAluIo, BaseAluOpcode, ExecutionRecord, Instruction, RecursionProgram,
};

use crate::builder::SP1RecursionAirBuilder;
14
/// Number of ALU operations packed into each trace row.
pub const NUM_BASE_ALU_ENTRIES_PER_ROW: usize = 1;

/// Chip proving base-field ALU operations (add/sub/mul/div).
#[derive(Default, Clone)]
pub struct BaseAluChip;

/// Width of the main trace in field elements. Instantiating the column struct
/// with `u8` makes `size_of` count exactly one byte per field-element slot.
pub const NUM_BASE_ALU_COLS: usize = core::mem::size_of::<BaseAluCols<u8>>();

/// Main-trace columns: one value entry per packed ALU operation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluCols<F: Copy> {
    pub values: [BaseAluValueCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}
27
/// Width of a single value entry, in field elements.
pub const NUM_BASE_ALU_VALUE_COLS: usize = core::mem::size_of::<BaseAluValueCols<u8>>();

/// Main-trace columns for one ALU operation: its input/output values.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluValueCols<F: Copy> {
    // `out`, `in1`, `in2` of the operation.
    pub vals: BaseAluIo<F>,
}

/// Width of the preprocessed trace, in field elements.
pub const NUM_BASE_ALU_PREPROCESSED_COLS: usize =
    core::mem::size_of::<BaseAluPreprocessedCols<u8>>();

/// Preprocessed-trace columns: one access entry per packed ALU operation.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluPreprocessedCols<F: Copy> {
    pub accesses: [BaseAluAccessCols<F>; NUM_BASE_ALU_ENTRIES_PER_ROW],
}
44
/// Width of a single access entry, in field elements.
pub const NUM_BASE_ALU_ACCESS_COLS: usize = core::mem::size_of::<BaseAluAccessCols<u8>>();

/// Program-fixed (preprocessed) description of one ALU instruction.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct BaseAluAccessCols<F: Copy> {
    /// Memory addresses of the output and the two inputs.
    pub addrs: BaseAluIo<Address<F>>,
    // Opcode selector flags; trace generation sets exactly one of them per
    // real instruction and leaves all of them zero on padding rows.
    pub is_add: F,
    pub is_sub: F,
    pub is_mul: F,
    pub is_div: F,
    /// Multiplicity with which the output value is sent to memory.
    pub mult: F,
}
57
impl<F: Field> BaseAir<F> for BaseAluChip {
    /// Number of columns in the main trace.
    fn width(&self) -> usize {
        NUM_BASE_ALU_COLS
    }
}
63
impl<F: PrimeField32> MachineAir<F> for BaseAluChip {
    type Record = ExecutionRecord<F>;

    type Program = RecursionProgram<F>;

    fn name(&self) -> &'static str {
        "BaseAlu"
    }

    fn preprocessed_width(&self) -> usize {
        NUM_BASE_ALU_PREPROCESSED_COLS
    }

    fn preprocessed_num_rows(&self, program: &Self::Program) -> Option<usize> {
        // Only BaseAlu instructions occupy rows in this chip's preprocessed
        // trace; count them and defer to the shared row-padding logic.
        let instrs_len = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction.inner() {
                Instruction::BaseAlu(x) => Some(x),
                _ => None,
            })
            .count();
        self.preprocessed_num_rows_with_instrs_len(program, instrs_len)
    }

    fn preprocessed_num_rows_with_instrs_len(
        &self,
        program: &Self::Program,
        instrs_len: usize,
    ) -> Option<usize> {
        // If the program prescribes a shape for this chip, pad up to that
        // height; `next_multiple_of_32` handles the rounding either way.
        let height = program.shape.as_ref().and_then(|shape| shape.height(self));
        let nb_rows = instrs_len.div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
        Some(next_multiple_of_32(nb_rows, height))
    }

    /// Writes the preprocessed trace (operand addresses, opcode flags, and
    /// output multiplicity per instruction) directly into `buffer`.
    fn generate_preprocessed_trace_into(
        &self,
        program: &Self::Program,
        buffer: &mut [MaybeUninit<F>],
    ) {
        // The raw zero-byte initialization below is only asserted sound for
        // SP1Field, so reject any other field at runtime.
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<SP1Field>(),
            "generate_preprocessed_trace only supports SP1Field field"
        );

        let instrs = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction.inner() {
                Instruction::BaseAlu(x) => Some(x),
                _ => None,
            })
            .collect::<Vec<_>>();
        let padded_nb_rows =
            self.preprocessed_num_rows_with_instrs_len(program, instrs.len()).unwrap();

        // SAFETY: `MaybeUninit<F>` has the same layout as `F`; the caller is
        // assumed to size `buffer` to at least
        // `padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS` elements — TODO
        // confirm against the `MachineAir` buffer contract.
        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values = unsafe {
            core::slice::from_raw_parts_mut(
                buffer_ptr,
                padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS,
            )
        };

        // Zero-fill everything past the last real instruction (padding rows).
        // SAFETY: the count is in `MaybeUninit<F>` elements, and the all-zero
        // byte pattern is assumed a valid F (the TypeId assert above restricts
        // F to SP1Field).
        unsafe {
            let padding_start = instrs.len() * NUM_BASE_ALU_ACCESS_COLS;
            let padding_size = padded_nb_rows * NUM_BASE_ALU_PREPROCESSED_COLS - padding_start;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }

        // Generate the trace rows & corresponding records for each chunk of events in parallel.
        let populate_len = instrs.len() * NUM_BASE_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_BASE_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let BaseAluInstr { opcode, mult, addrs } = instr;
                let access: &mut BaseAluAccessCols<_> = row.borrow_mut();
                // Clear every opcode flag, then set exactly the one selected
                // by `opcode`.
                *access = BaseAluAccessCols {
                    addrs: addrs.to_owned(),
                    is_add: F::from_bool(false),
                    is_sub: F::from_bool(false),
                    is_mul: F::from_bool(false),
                    is_div: F::from_bool(false),
                    mult: mult.to_owned(),
                };
                let target_flag = match opcode {
                    BaseAluOpcode::AddF => &mut access.is_add,
                    BaseAluOpcode::SubF => &mut access.is_sub,
                    BaseAluOpcode::MulF => &mut access.is_mul,
                    BaseAluOpcode::DivF => &mut access.is_div,
                };
                *target_flag = F::from_bool(true);
            },
        );
    }

    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
        // This is a no-op.
    }

    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        // Mirrors `preprocessed_num_rows_with_instrs_len`, but sized from the
        // runtime event list instead of the static instruction list.
        let height = input.program.shape.as_ref().and_then(|shape| shape.height(self));
        let nb_rows = input.base_alu_events.len().div_ceil(NUM_BASE_ALU_ENTRIES_PER_ROW);
        Some(next_multiple_of_32(nb_rows, height))
    }

    /// Writes the main trace (the `out`/`in1`/`in2` values of every ALU event)
    /// directly into `buffer`.
    fn generate_trace_into(
        &self,
        input: &ExecutionRecord<F>,
        _: &mut ExecutionRecord<F>,
        buffer: &mut [MaybeUninit<F>],
    ) {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<SP1Field>(),
            "generate_trace_into only supports SP1Field"
        );

        let events = &input.base_alu_events;
        let padded_nb_rows = self.num_rows(input).unwrap();
        // One event per row, since NUM_BASE_ALU_ENTRIES_PER_ROW == 1.
        let num_event_rows = events.len();

        // Zero-fill the padding rows past the last event.
        // SAFETY: same reasoning as in `generate_preprocessed_trace_into` —
        // element counts, and zero bytes are a valid SP1Field value.
        unsafe {
            let padding_start = num_event_rows * NUM_BASE_ALU_COLS;
            let padding_size = (padded_nb_rows - num_event_rows) * NUM_BASE_ALU_COLS;
            if padding_size > 0 {
                core::ptr::write_bytes(buffer[padding_start..].as_mut_ptr(), 0, padding_size);
            }
        }

        // SAFETY: reinterprets only the event-populated prefix of `buffer`;
        // every element of that prefix is written by the loop below.
        let buffer_ptr = buffer.as_mut_ptr() as *mut F;
        let values = unsafe {
            core::slice::from_raw_parts_mut(buffer_ptr, num_event_rows * NUM_BASE_ALU_COLS)
        };

        // Generate the trace rows & corresponding records for each chunk of events in parallel.
        let populate_len = events.len() * NUM_BASE_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_BASE_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut BaseAluValueCols<_> = row.borrow_mut();
                *cols = BaseAluValueCols { vals };
            },
        );
    }

    fn included(&self, _record: &Self::Record) -> bool {
        // The chip is always included, regardless of the record's contents.
        true
    }
}
215
impl<AB> Air<AB> for BaseAluChip
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    /// Emits the ALU constraints for one row: for each packed entry, the
    /// flagged operation must relate `in1`, `in2`, and `out`, and the operands
    /// are wired to memory via receive/send interactions.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &BaseAluCols<AB::Var> = (*local).borrow();
        // Opcode flags, addresses, and multiplicities come from the
        // preprocessed (program-fixed) trace.
        let prep = builder.preprocessed();
        let prep_local = prep.row_slice(0);
        let prep_local: &BaseAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

        for (
            BaseAluValueCols { vals: BaseAluIo { out, in1, in2 } },
            BaseAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
        ) in zip(local.values, prep_local.accesses)
        {
            // Check exactly one flag is enabled. (`assert_bool` on the sum
            // also permits all-zero flags, which is the padding-row case.)
            let is_real = is_add + is_sub + is_mul + is_div;
            builder.assert_bool(is_real.clone());

            // Each constraint is gated by its own flag. Sub and div are
            // phrased inverse-free: in1 = in2 + out and in1 = in2 * out.
            builder.when(is_add).assert_eq(in1 + in2, out);
            builder.when(is_sub).assert_eq(in1, in2 + out);
            builder.when(is_mul).assert_eq(out, in1 * in2);
            builder.when(is_div).assert_eq(in2 * out, in1);

            // Read the inputs from memory.
            builder.receive_single(addrs.in1, in1, is_real.clone());
            builder.receive_single(addrs.in2, in2, is_real);

            // Write the output to memory.
            builder.send_single(addrs.out, out, mult);
        }
    }
}
251
#[cfg(test)]
mod tests {

    use rand::prelude::*;
    use sp1_recursion_executor::{instruction as instr, BaseAluOpcode, MemAccessKind};

    use crate::{chips::test_fixtures, test::test_recursion_linear_program};

    use super::*;

    /// Main-trace generation on the shared fixture shard yields a trace
    /// taller than the fixture's minimum row count.
    #[tokio::test]
    async fn generate_trace() {
        let shard = test_fixtures::shard().await;
        let trace = BaseAluChip.generate_trace(shard, &mut ExecutionRecord::default());
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// Preprocessed-trace generation on the fixture program yields a trace
    /// taller than the fixture's minimum row count.
    #[tokio::test]
    async fn generate_preprocessed_trace() {
        let program = &test_fixtures::program_with_input().await.0;
        let trace = BaseAluChip.generate_preprocessed_trace(program).unwrap();
        assert!(trace.height() > test_fixtures::MIN_ROWS);
    }

    /// End-to-end: for each of the four opcodes, writes random inputs to
    /// memory, runs the ALU instruction, and reads back the expected result.
    #[tokio::test]
    pub async fn four_ops() {
        // Fixed seed keeps the generated program deterministic.
        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_felt = move || -> SP1Field { rng.sample(rand::distributions::Standard) };
        let mut addr = 0;

        let instructions = (0..1000)
            .flat_map(|_| {
                // Construct in1 = in2 * quot so DivF's exact answer is `quot`.
                let quot = random_felt();
                let in2 = random_felt();
                let in1 = in2 * quot;
                // Six fresh addresses per round: a[0..2] hold the inputs,
                // a[2..6] hold one output per opcode.
                let alloc_size = 6;
                let a = (0..alloc_size).map(|x| x + addr).collect::<Vec<_>>();
                addr += alloc_size;
                [
                    instr::mem_single(MemAccessKind::Write, 4, a[0], in1),
                    instr::mem_single(MemAccessKind::Write, 4, a[1], in2),
                    instr::base_alu(BaseAluOpcode::AddF, 1, a[2], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[2], in1 + in2),
                    instr::base_alu(BaseAluOpcode::SubF, 1, a[3], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[3], in1 - in2),
                    instr::base_alu(BaseAluOpcode::MulF, 1, a[4], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[4], in1 * in2),
                    instr::base_alu(BaseAluOpcode::DivF, 1, a[5], a[0], a[1]),
                    instr::mem_single(MemAccessKind::Read, 1, a[5], quot),
                ]
            })
            .collect::<Vec<Instruction<SP1Field>>>();

        test_recursion_linear_program(instructions).await;
    }
}