1use core::borrow::Borrow;
2use p3_air::{Air, BaseAir, PairBuilder};
3use p3_baby_bear::BabyBear;
4use p3_field::{extension::BinomiallyExtendable, AbstractField, Field, PrimeField32};
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6use p3_maybe_rayon::prelude::*;
7use sp1_core_machine::utils::next_power_of_two;
8use sp1_derive::AlignedBorrow;
9use sp1_stark::air::{ExtensionAirBuilder, MachineAir};
10use std::{borrow::BorrowMut, iter::zip};
11
12use crate::{builder::SP1RecursionAirBuilder, *};
13
/// Number of ALU entries packed into each trace row (one value entry in the
/// main trace paired with one access entry in the preprocessed trace).
pub const NUM_EXT_ALU_ENTRIES_PER_ROW: usize = 4;

/// Chip for extension-field ALU operations (add/sub/mul/div over the degree-D
/// binomial extension of BabyBear).
#[derive(Default)]
pub struct ExtAluChip;

/// Width of the main trace, in columns (measured via `size_of` with one-byte cells).
pub const NUM_EXT_ALU_COLS: usize = core::mem::size_of::<ExtAluCols<u8>>();

/// Main trace row: `NUM_EXT_ALU_ENTRIES_PER_ROW` independent value entries.
/// `repr(C)` + `AlignedBorrow` allow a flat row slice to be reinterpreted as
/// this struct, so the layout must not be reordered.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluCols<F: Copy> {
    pub values: [ExtAluValueCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}
/// Width of a single value entry, in columns.
const NUM_EXT_ALU_VALUE_COLS: usize = core::mem::size_of::<ExtAluValueCols<u8>>();

/// One ALU entry of the main trace: the in1/in2/out extension elements,
/// each stored as a `Block` of base-field limbs.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluValueCols<F: Copy> {
    pub vals: ExtAluIo<Block<F>>,
}
33
/// Width of the preprocessed trace, in columns.
pub const NUM_EXT_ALU_PREPROCESSED_COLS: usize = core::mem::size_of::<ExtAluPreprocessedCols<u8>>();

/// Preprocessed trace row: one access entry per ALU entry in the main trace.
/// `repr(C)` layout must mirror `ExtAluCols` entry-for-entry.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluPreprocessedCols<F: Copy> {
    pub accesses: [ExtAluAccessCols<F>; NUM_EXT_ALU_ENTRIES_PER_ROW],
}

/// Width of a single access entry, in columns.
pub const NUM_EXT_ALU_ACCESS_COLS: usize = core::mem::size_of::<ExtAluAccessCols<u8>>();

/// One ALU entry of the preprocessed trace: operand/result addresses, a
/// one-hot set of opcode selector flags (all zero for padding entries), and
/// the multiplicity with which the result is sent to the memory argument.
#[derive(AlignedBorrow, Debug, Clone, Copy)]
#[repr(C)]
pub struct ExtAluAccessCols<F: Copy> {
    pub addrs: ExtAluIo<Address<F>>,
    pub is_add: F,
    pub is_sub: F,
    pub is_mul: F,
    pub is_div: F,
    pub mult: F,
}
54
impl<F: Field> BaseAir<F> for ExtAluChip {
    /// The main trace width is the fixed size of `ExtAluCols`.
    fn width(&self) -> usize {
        NUM_EXT_ALU_COLS
    }
}
60
impl<F: PrimeField32 + BinomiallyExtendable<D>> MachineAir<F> for ExtAluChip {
    type Record = ExecutionRecord<F>;

    type Program = crate::RecursionProgram<F>;

    fn name(&self) -> String {
        "ExtAlu".to_string()
    }

    fn preprocessed_width(&self) -> usize {
        NUM_EXT_ALU_PREPROCESSED_COLS
    }

    /// Height of the preprocessed trace: instructions are packed
    /// `NUM_EXT_ALU_ENTRIES_PER_ROW` per row, then padded up to either the
    /// program's fixed height for this chip (if configured) or the next
    /// power of two.
    fn preprocessed_num_rows(&self, program: &Self::Program, instrs_len: usize) -> Option<usize> {
        let nb_rows = instrs_len.div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
        let fixed_log2_rows = program.fixed_log2_rows(self);
        Some(match fixed_log2_rows {
            Some(log2_rows) => 1 << log2_rows,
            None => next_power_of_two(nb_rows, None),
        })
    }

    /// Builds the preprocessed trace from the program's `ExtAlu` instructions.
    /// Row population is delegated to a BabyBear-specialized native routine,
    /// so any other field is rejected at runtime.
    fn generate_preprocessed_trace(&self, program: &Self::Program) -> Option<RowMajorMatrix<F>> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_preprocessed_trace only supports BabyBear field"
        );

        // SAFETY: the TypeId assertion above guarantees F == BabyBear, so this
        // transmute only renames the element type; the layout is unchanged.
        let instrs = unsafe {
            std::mem::transmute::<Vec<&ExtAluInstr<F>>, Vec<&ExtAluInstr<BabyBear>>>(
                program
                    .inner
                    .iter()
                    .filter_map(|instruction| match instruction {
                        Instruction::ExtAlu(x) => Some(x),
                        _ => None,
                    })
                    .collect::<Vec<_>>(),
            )
        };
        let padded_nb_rows = self.preprocessed_num_rows(program, instrs.len()).unwrap();
        // Zero-initialized buffer: padding entries keep all selector flags and
        // mults at zero, so they add no constraints or interactions.
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS];

        // Fill one access-entry chunk per instruction, in parallel.
        let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let access: &mut ExtAluAccessCols<_> = row.borrow_mut();
                // SAFETY: each chunk is exactly NUM_EXT_ALU_ACCESS_COLS elements
                // and `ExtAluAccessCols` is repr(C), so the native write fills
                // precisely this chunk.
                unsafe {
                    crate::sys::alu_ext_instr_to_row_babybear(instr, access);
                }
            },
        );

        // SAFETY: F == BabyBear (asserted above); transmute back to the generic field.
        Some(RowMajorMatrix::new(
            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<F>>(values) },
            NUM_EXT_ALU_PREPROCESSED_COLS,
        ))
    }

    // This chip records no cross-record dependencies.
    fn generate_dependencies(&self, _: &Self::Record, _: &mut Self::Record) {
    }

    /// Height of the main trace: events packed `NUM_EXT_ALU_ENTRIES_PER_ROW`
    /// per row, padded the same way as `preprocessed_num_rows`.
    fn num_rows(&self, input: &Self::Record) -> Option<usize> {
        let events = &input.ext_alu_events;
        let nb_rows = events.len().div_ceil(NUM_EXT_ALU_ENTRIES_PER_ROW);
        let fixed_log2_rows = input.fixed_log2_rows(self);
        Some(match fixed_log2_rows {
            Some(log2_rows) => 1 << log2_rows,
            None => next_power_of_two(nb_rows, None),
        })
    }

    /// Builds the main trace from the record's ext-ALU events. Like the
    /// preprocessed path, this is BabyBear-only because row population goes
    /// through a native routine.
    fn generate_trace(&self, input: &Self::Record, _: &mut Self::Record) -> RowMajorMatrix<F> {
        assert_eq!(
            std::any::TypeId::of::<F>(),
            std::any::TypeId::of::<BabyBear>(),
            "generate_trace only supports BabyBear field"
        );

        // SAFETY: F == BabyBear per the assertion above; identical layout.
        let events = unsafe {
            std::mem::transmute::<&Vec<ExtAluIo<Block<F>>>, &Vec<ExtAluIo<Block<BabyBear>>>>(
                &input.ext_alu_events,
            )
        };
        let padded_nb_rows = self.num_rows(input).unwrap();
        // Zero padding is sound: the matching preprocessed flags are zero there.
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_COLS];

        // Fill one value-entry chunk per event, in parallel.
        let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut ExtAluValueCols<_> = row.borrow_mut();
                // SAFETY: each chunk is exactly NUM_EXT_ALU_VALUE_COLS elements
                // and `ExtAluValueCols` is repr(C), so the native write fills
                // precisely this chunk.
                unsafe {
                    crate::sys::alu_ext_event_to_row_babybear(&vals, cols);
                }
            },
        );

        // SAFETY: F == BabyBear (asserted above).
        RowMajorMatrix::new(
            unsafe { std::mem::transmute::<Vec<BabyBear>, Vec<F>>(values) },
            NUM_EXT_ALU_COLS,
        )
    }

    // Unconditionally included in the machine, even when there are no events.
    fn included(&self, _record: &Self::Record) -> bool {
        true
    }

    // NOTE(review): presumably marks this chip's interactions as local to a
    // single shard — confirm against the `MachineAir` trait documentation.
    fn local_only(&self) -> bool {
        true
    }
}
178
impl<AB> Air<AB> for ExtAluChip
where
    AB: SP1RecursionAirBuilder + PairBuilder,
{
    /// Emits the constraints for every ALU entry in the row: opcode semantics
    /// plus the memory-argument interactions for operands and result.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &ExtAluCols<AB::Var> = (*local).borrow();
        let prep = builder.preprocessed();
        let prep_local = prep.row_slice(0);
        let prep_local: &ExtAluPreprocessedCols<AB::Var> = (*prep_local).borrow();

        // Pair each main-trace value entry with its preprocessed access entry.
        for (
            ExtAluValueCols { vals },
            ExtAluAccessCols { addrs, is_add, is_sub, is_mul, is_div, mult },
        ) in zip(local.values, prep_local.accesses)
        {
            let in1 = vals.in1.as_extension::<AB>();
            let in2 = vals.in2.as_extension::<AB>();
            let out = vals.out.as_extension::<AB>();

            // The selector sum must be 0 (padding entry) or 1 (exactly one op).
            let is_real = is_add + is_sub + is_mul + is_div;
            builder.assert_bool(is_real.clone());

            // Opcode semantics. Sub and div are stated against `out` so no
            // field inverses are required:
            //   add: out = in1 + in2
            //   sub: in1 = in2 + out   (i.e. out = in1 - in2)
            //   mul: out = in1 * in2
            //   div: in1 = in2 * out   (i.e. out = in1 / in2)
            builder.when(is_add).assert_ext_eq(in1.clone() + in2.clone(), out.clone());
            builder.when(is_sub).assert_ext_eq(in1.clone(), in2.clone() + out.clone());
            builder.when(is_mul).assert_ext_eq(in1.clone() * in2.clone(), out.clone());
            builder.when(is_div).assert_ext_eq(in1, in2 * out);

            // Consume both operands from the memory argument when the entry is
            // real (multiplicity is_real ∈ {0, 1}).
            builder.receive_block(addrs.in1, vals.in1, is_real.clone());

            builder.receive_block(addrs.in2, vals.in2, is_real);

            // Publish the result with its preprocessed multiplicity.
            builder.send_block(addrs.out, vals.out, mult);
        }
    }
}
219
#[cfg(test)]
mod tests {
    use crate::{chips::test_fixtures, runtime::instruction as instr};
    use machine::tests::test_recursion_linear_program;
    use p3_baby_bear::BabyBear;
    use p3_field::{extension::BinomialExtensionField, AbstractExtensionField, AbstractField};
    use p3_matrix::dense::RowMajorMatrix;
    use rand::{rngs::StdRng, Rng, SeedableRng};
    use sp1_stark::StarkGenericConfig;
    use stark::BabyBearPoseidon2Outer;

    use super::*;

    /// Pure-Rust mirror of `ExtAluChip::generate_trace` (no FFI), used to
    /// cross-check the native row population.
    fn generate_trace_reference(
        input: &ExecutionRecord<BabyBear>,
        _: &mut ExecutionRecord<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        let events = &input.ext_alu_events;
        let padded_nb_rows = ExtAluChip.num_rows(input).unwrap();
        let mut values = vec![BabyBear::zero(); padded_nb_rows * NUM_EXT_ALU_COLS];

        let populate_len = events.len() * NUM_EXT_ALU_VALUE_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_VALUE_COLS).zip_eq(events).for_each(
            |(row, &vals)| {
                let cols: &mut ExtAluValueCols<_> = row.borrow_mut();
                // Reference behavior: copy the event values straight into the row.
                *cols = ExtAluValueCols { vals };
            },
        );

        RowMajorMatrix::new(values, NUM_EXT_ALU_COLS)
    }

    /// The FFI-backed main trace must equal the pure-Rust reference exactly.
    #[test]
    fn generate_trace() {
        let shard = test_fixtures::shard();
        let mut execution_record = test_fixtures::default_execution_record();
        let trace = ExtAluChip.generate_trace(&shard, &mut execution_record);
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_trace_reference(&shard, &mut execution_record));
    }

    /// Pure-Rust mirror of `ExtAluChip::generate_preprocessed_trace` (no FFI).
    fn generate_preprocessed_trace_reference(
        program: &RecursionProgram<BabyBear>,
    ) -> RowMajorMatrix<BabyBear> {
        type F = BabyBear;

        let instrs = program
            .inner
            .iter()
            .filter_map(|instruction| match instruction {
                Instruction::ExtAlu(x) => Some(x),
                _ => None,
            })
            .collect::<Vec<_>>();
        let padded_nb_rows = ExtAluChip.preprocessed_num_rows(program, instrs.len()).unwrap();
        let mut values = vec![F::zero(); padded_nb_rows * NUM_EXT_ALU_PREPROCESSED_COLS];

        let populate_len = instrs.len() * NUM_EXT_ALU_ACCESS_COLS;
        values[..populate_len].par_chunks_mut(NUM_EXT_ALU_ACCESS_COLS).zip_eq(instrs).for_each(
            |(row, instr)| {
                let ExtAluInstr { opcode, mult, addrs } = instr;
                let access: &mut ExtAluAccessCols<_> = row.borrow_mut();
                // Start with all selector flags clear, then set exactly the one
                // matching the instruction's opcode (one-hot).
                *access = ExtAluAccessCols {
                    addrs: addrs.to_owned(),
                    is_add: F::from_bool(false),
                    is_sub: F::from_bool(false),
                    is_mul: F::from_bool(false),
                    is_div: F::from_bool(false),
                    mult: mult.to_owned(),
                };
                let target_flag = match opcode {
                    ExtAluOpcode::AddE => &mut access.is_add,
                    ExtAluOpcode::SubE => &mut access.is_sub,
                    ExtAluOpcode::MulE => &mut access.is_mul,
                    ExtAluOpcode::DivE => &mut access.is_div,
                };
                *target_flag = F::from_bool(true);
            },
        );

        RowMajorMatrix::new(values, NUM_EXT_ALU_PREPROCESSED_COLS)
    }

    /// The FFI-backed preprocessed trace must equal the pure-Rust reference.
    #[test]
    #[ignore = "Failing due to merge conflicts. Will be fixed shortly."]
    fn generate_preprocessed_trace() {
        let program = test_fixtures::program();
        let trace = ExtAluChip.generate_preprocessed_trace(&program).unwrap();
        assert!(trace.height() >= test_fixtures::MIN_TEST_CASES);

        assert_eq!(trace, generate_preprocessed_trace_reference(&program));
    }

    /// End-to-end proof of a program exercising all four extension-field ops,
    /// with each result checked via a memory read of the expected value.
    #[test]
    pub fn four_ops() {
        type SC = BabyBearPoseidon2Outer;
        type F = <SC as StarkGenericConfig>::Val;

        let mut rng = StdRng::seed_from_u64(0xDEADBEEF);
        let mut random_extfelt = move || {
            let inner: [F; 4] = core::array::from_fn(|_| rng.sample(rand::distributions::Standard));
            BinomialExtensionField::<F, D>::from_base_slice(&inner)
        };
        let mut addr = 0;

        let instructions = (0..1000)
            .flat_map(|_| {
                // Build in1 = in2 * quot so that in1 / in2 recovers quot exactly.
                let quot = random_extfelt();
                let in2 = random_extfelt();
                let in1 = in2 * quot;
                let alloc_size = 6;
                // Fresh addresses: [in1, in2, add, sub, mul, div results].
                let a = (0..alloc_size).map(|x| x + addr).collect::<Vec<_>>();
                addr += alloc_size;
                [
                    // Each input is written with multiplicity 4: it is read by
                    // all four ALU instructions below.
                    instr::mem_ext(MemAccessKind::Write, 4, a[0], in1),
                    instr::mem_ext(MemAccessKind::Write, 4, a[1], in2),
                    // Each op result (mult 1) is verified by a single read.
                    instr::ext_alu(ExtAluOpcode::AddE, 1, a[2], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[2], in1 + in2),
                    instr::ext_alu(ExtAluOpcode::SubE, 1, a[3], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[3], in1 - in2),
                    instr::ext_alu(ExtAluOpcode::MulE, 1, a[4], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[4], in1 * in2),
                    instr::ext_alu(ExtAluOpcode::DivE, 1, a[5], a[0], a[1]),
                    instr::mem_ext(MemAccessKind::Read, 1, a[5], quot),
                ]
            })
            .collect::<Vec<Instruction<F>>>();

        test_recursion_linear_program(instructions);
    }
}