sp1_recursion_core/chips/
alu_ext.rs

1use core::borrow::Borrow;
2use p3_air::{Air, BaseAir, PairBuilder};
3use p3_baby_bear::BabyBear;
4use p3_field::{extension::BinomiallyExtendable, AbstractField, Field, PrimeField32};
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::next_power_of_two;
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::{ExtensionAirBuilder, MachineAir};
10use std::{borrow::BorrowMut, iter::zip};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
/// Number of extension-field ALU operations packed into a single trace row.
pub const NUM_EXT_ALU_ENTRIES_PER_ROW: usize = 4;

/// Chip proving extension-field ALU operations (add/sub/mul/div) for the recursion VM.
#[derive(Default)]
pub struct ExtAluChip;

/// Width (in field elements) of the main trace, derived from the column struct layout.
pub const NUM_EXT_ALU_COLS: usize = core::mem::size_of::<ExtAluCols<u8>>();
20
/// Main trace columns: `NUM_EXT_ALU_ENTRIES_PER_ROW` value records per row.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluCols<F: Copy> {
    pub values: [ExtAluValueCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}
/// Width (in field elements) of a single `ExtAluValueCols` entry.
const NUM_EXT_ALU_VALUE_COLS: usize = core::mem::size_of::<ExtAluValueCols<u8>>();
27
/// Columns holding the operand/result values (`in1`, `in2`, `out`) of one ALU
/// operation, each stored as a `Block` of base-field limbs.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluValueCols<F: Copy> {
    pub vals: ExtAluIo<Block<F>>,
}

/// Width (in field elements) of the preprocessed trace.
pub const NUM_EXT_ALU_PREPROCESSED_COLS: usize = core::mem::size_of::<ExtAluPreprocessedCols<u8>>();
35
/// Preprocessed trace columns: one access record per packed ALU entry,
/// mirroring the layout of `ExtAluCols::values`.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluPreprocessedCols<F: Copy> {
    pub accesses: [ExtAluAccessCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}

/// Width (in field elements) of a single `ExtAluAccessCols` entry.
pub const NUM_EXT_ALU_ACCESS_COLS: usize = core::mem::size_of::<ExtAluAccessCols<u8>>();
43
/// Preprocessed columns describing one ALU access: operand/result addresses,
/// opcode selector flags, and the multiplicity of the output write.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluAccessCols<F: Copy> {
    /// Memory addresses of `in1`, `in2`, and `out`.
    pub addrs: ExtAluIo<Address<F>>,
    // Opcode selector flags; at most one is set, and all are zero on padding
    // entries (the AIR asserts their sum is boolean — see `eval`).
    pub is_add: F,
    pub is_sub: F,
    pub is_mul: F,
    pub is_div: F,
    /// Multiplicity with which `out` is sent to memory.
    pub mult: F,
}
54
55impl<F: Field> BaseAir<F> for ExtAluChip {
56    fn width(&self) -> usize {
57        NUM_EXT_ALU_COLS
58    }
59}
60
impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for ExtAluChip {
    type Record = ExecutionRecord<F>;

    type Program = crate::RecursionProgram<F>;

    fn name(&self) -> String {
        "ExtAlu".to_string()
    }

    fn preprocessed_width(&self) -> usize {
        NUM_EXT_ALU_PREPROCESSED_COLS
    }

    /// Number of preprocessed rows: instructions are packed
    /// `NUM_EXT_ALU_ENTRIES_PER_ROW` per row, then padded to the program's
    /// fixed height for this chip if one is set, otherwise to the next power
    /// of two.
    fn preprocessed_num_rows(&self, program: &Self::Program, instrs_len: usize) -> Option<usize> {
        let nb_rows = instrs_len.div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
        let fixed_log2_rows = program.fixed_log2_rows(self);
        Some(match fixed_log2_rows {
            Some(log2_rows) => 1 << log2_rows,
            None => next_power_of_two(nb_rows, None),
        })
    }

    /// Builds the preprocessed trace: one `ExtAluAccessCols` record per
    /// `ExtAlu` instruction, populated by the FFI helper. Padding rows are
    /// left all-zero, so their selector flags are off.
    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_preprocessed_trace only supports BabyBear field"
        );

        // SAFETY: the assertion above guarantees `F` is exactly `BabyBear`,
        // so this transmutes between identical reference types.
        let instrs = unsafe {
            std::mem::transmute::<Vec<&ExtAluInstr<F>>, Vec<&ExtAluInstr<BabyBear>>>(
                program
                    .inner
                    .iter()
                    .filter_map(|instruction| match instruction {
                        Instruction::ExtAlu(x) => Some(x),
                        _ => None,
                    })
                    .collect::<Vec<_>>(),
            )
        };
        let padded_nb_rows = self.preprocessed_num_rows(program, instrs.len()).unwrap();
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS];

        // Generate the trace rows & corresponding records for each chunk of events in parallel.
        let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let access: &mut ExtAluAccessCols<_> = row.borrow_mut();
                // SAFETY: `access` covers exactly NUM_EXT_ALU_ACCESS_COLS
                // elements of a `repr(C)` struct, the layout the FFI row
                // writer expects.
                unsafe {
                    crate::sys::alu_ext_instr_to_row_babybear(instr, access);
                }
            },
        );

        // Convert the trace to a row major matrix.
        // SAFETY: `F` is `BabyBear` (asserted above), so element types match.
        Some(RowMajorMatrix::new(
            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<F>>(values) },
            NUM_EXT_ALU_PREPROCESSED_COLS,
        ))
    }

    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
        // This is a no-op.
    }

    /// Number of main-trace rows: events packed `NUM_EXT_ALU_ENTRIES_PER_ROW`
    /// per row, padded analogously to `preprocessed_num_rows`.
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let events = &input.ext_alu_events;
        let nb_rows = events.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
        let fixed_log2_rows = input.fixed_log2_rows(self);
        Some(match fixed_log2_rows {
            Some(log2_rows) => 1 << log2_rows,
            None => next_power_of_two(nb_rows, None),
        })
    }

    /// Builds the main trace: one `ExtAluValueCols` record per ALU event,
    /// populated by the FFI helper. Padding rows stay zero.
    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_trace only supports BabyBear field"
        );

        // SAFETY: the assertion above guarantees `F` is exactly `BabyBear`.
        let events = unsafe {
            std::mem::transmute::<&Vec<ExtAluIo<Block<F>>>, &Vec<ExtAluIo<Block<BabyBear>>>>(
                &input.ext_alu_events,
            )
        };
        let padded_nb_rows = self.num_rows(input).unwrap();
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_COLS];

        // Generate the trace rows & corresponding records for each chunk of events in parallel.
        let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut ExtAluValueCols<_> = row.borrow_mut();
                // SAFETY: `cols` covers exactly NUM_EXT_ALU_VALUE_COLS
                // elements of a `repr(C)` struct, the layout the FFI row
                // writer expects.
                unsafe {
                    crate::sys::alu_ext_event_to_row_babybear(&vals, cols);
                }
            },
        );

        // Convert the trace to a row major matrix.
        // SAFETY: `F` is `BabyBear` (asserted above), so element types match.
        RowMajorMatrix::new(
            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<F>>(values) },
            NUM_EXT_ALU_COLS,
        )
    }

    /// This chip is unconditionally included in the machine.
    fn included(&self, _record: &Self::Record) -> bool {
        true
    }

    /// Marked local-only; presumably its interactions are resolved within a
    /// shard — confirm against the `MachineAir` trait docs.
    fn local_only(&self) -> bool {
        true
    }
}
178
impl<AB> Air<AB> for ExtAluChip
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    /// Constrains each packed ALU entry: the opcode-selected arithmetic
    /// relation between `in1`, `in2`, `out`, plus the memory interactions.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &ExtAluCols<AB::Var> = (*local).borrow();
        let prep = builder.preprocessed();
        let prep_local = prep.row_slice(0);
        let prep_local: &ExtAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

        // Pair each value record with its matching preprocessed access record.
        for (
            ExtAluValueCols { vals },
            ExtAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
        ) in zip(local.values, prep_local.accesses)
        {
            let in1 = vals.in1.as_extension::<AB>();
            let in2 = vals.in2.as_extension::<AB>();
            let out = vals.out.as_extension::<AB>();

            // The flag sum must be boolean, i.e. at most one opcode flag is
            // set (zero on padding entries). The flags come from the
            // preprocessed trace, so they are fixed at setup time.
            let is_real = is_add + is_sub + is_mul + is_div;
            builder.assert_bool(is_real.clone());

            // Arithmetic relations. Sub and div are phrased against `out`
            // without inverses: in1 = in2 + out <=> out = in1 - in2, and
            // in1 = in2 * out <=> out = in1 / in2 (satisfiable when in2 != 0).
            builder.when(is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone());
            builder.when(is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone());
            builder.when(is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone());
            builder.when(is_div).assert_ext_eq(in1, in2 * out);

            // Read the inputs from memory (only for real entries).
            builder.receive_block(addrs.in1, vals.in1, is_real.clone());

            builder.receive_block(addrs.in2, vals.in2, is_real);

            // Write the output to memory with the prescribed multiplicity.
            builder.send_block(addrs.out, vals.out, mult);
        }
    }
}
219
#[cfg(test)]
mod tests {
    use crate::{chips::test_fixtures, runtime::instruction as instr};
    use machine::tests::test_recursion_linear_program;
    use p3_baby_bear::BabyBear;
    use p3_field::{extension::BinomialExtensionField, AbstractExtensionField, AbstractField};
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_stark::StarkGenericConfig;
    use stark::BabyBearPoseidon2Outer;

    use super::*;

    /// Pure-Rust reference implementation of `generate_trace`, used to
    /// cross-check the FFI-backed row population.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        let events = &input.ext_alu_events;
        let padded_nb_rows = ExtAluChip.num_rows(input).unwrap();
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_COLS];

        let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut ExtAluValueCols<_> = row.borrow_mut();
                *cols = ExtAluValueCols { vals };
            },
        );

        RowMajorMatrix::new(values, NUM_EXT_ALU_COLS)
    }

    // The FFI-generated main trace must match the reference implementation.
    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = ExtAluChip.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference(&shard, &mut execution_record));
    }

    /// Pure-Rust reference implementation of `generate_preprocessed_trace`.
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let instrs = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::ExtAlu(x) => Some(x),
                _ => None,
            })
            .collect::<Vec<_>>();
        let padded_nb_rows = ExtAluChip.preprocessed_num_rows(program, instrs.len()).unwrap();
        let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS];

        let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let ExtAluInstr { opcode, mult, addrs } = instr;
                let access: &mut ExtAluAccessCols<_> = row.borrow_mut();
                // Write the record with all flags cleared, then set the one
                // flag selected by the opcode.
                *access = ExtAluAccessCols {
                    addrs: addrs.to_owned(),
                    is_add: F::from_bool(false),
                    is_sub: F::from_bool(false),
                    is_mul: F::from_bool(false),
                    is_div: F::from_bool(false),
                    mult: mult.to_owned(),
                };
                let target_flag = match opcode {
                    ExtAluOpcode::AddE => &mut access.is_add,
                    ExtAluOpcode::SubE => &mut access.is_sub,
                    ExtAluOpcode::MulE => &mut access.is_mul,
                    ExtAluOpcode::DivE => &mut access.is_div,
                };
                *target_flag = F::from_bool(true);
            },
        );

        RowMajorMatrix::new(values, NUM_EXT_ALU_PREPROCESSED_COLS)
    }

    // The FFI-generated preprocessed trace must match the reference
    // implementation.
    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = ExtAluChip.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }

    // End-to-end proof test exercising all four extension-field ALU opcodes.
    #[test]
    pub fn four_ops() {
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_extfelt = move || {
            let inner: [F; 4] = core::array::from_fn(|_| rng.sample(rand::distributions::Standard));
            BinomialExtensionField::<F, D>::from_base_slice(&inner)
        };
        let mut addr = 0;

        let instructions = (0..1000)
            .flat_map(|_| {
                // Construct in1 = in2 * quot so the division result is exactly
                // `quot`.
                let quot = random_extfelt();
                let in2 = random_extfelt();
                let in1 = in2 * quot;
                // Six fresh addresses: two inputs plus one output per opcode.
                let alloc_size = 6;
                let a = (0..alloc_size).map(|x| x + addr).collect::<Vec<_>>();
                addr += alloc_size;
                [
                    instr::mem_ext(MemAccessKind::Write, 4, a[0], in1),
                    instr::mem_ext(MemAccessKind::Write, 4, a[1], in2),
                    instr::ext_alu(ExtAluOpcode::AddE, 1, a[2], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[2], in1 + in2),
                    instr::ext_alu(ExtAluOpcode::SubE, 1, a[3], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[3], in1 - in2),
                    instr::ext_alu(ExtAluOpcode::MulE, 1, a[4], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[4], in1 * in2),
                    instr::ext_alu(ExtAluOpcode::DivE, 1, a[5], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[5], quot),
                ]
            })
            .collect::<Vec<Instruction<F>>>();

        test_recursion_linear_program(instructions);
    }
}